[1/N] QATTrainer training workflow fixes and clean up; added backend-specific unit tests #318
Conversation
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/quantization/utils.py (1)
459-466
: Symmetric fix for restore path. Use the model-aware name to match saved keys across wrapper changes.
Apply this diff:
```diff
-    for name, module in model.named_modules():
-        if isinstance(module, TensorQuantizer) and get_unwrapped_name(name) in quantizer_state_dict:
-            module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            key = get_unwrapped_name(name, model)
+            if key in quantizer_state_dict:
+                module.load_state_dict(quantizer_state_dict[key])
```
examples/llm_qat/launch.sh (1)
126-129
: Division by zero when no GPUs detected.
`DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))` fails when `GPU_COUNT=0` (CPU-only CI/dev). Guard it.
Apply this diff:
```diff
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; n=torch.cuda.device_count(); print(n if n>0 else 1)")
+# Calculate save_steps (fallback to 1 on CPU-only)
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
```
🧹 Nitpick comments (10)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
161-164
: Make context manager re-entrant-safe by restoring prior state
If `quantize_weight()` is nested or called when `_enable_weight_quantization` was already `True`, hard-resetting to `False` on exit can flip the state incorrectly. Preserve and restore the previous value.
```diff
 def quantize_weight(self):
     """Context in which `self.weight` is quantized."""
-    self._enable_weight_quantization = True
-    try:
-        yield
-    finally:
-        self._enable_weight_quantization = False
+    prev = getattr(self, "_enable_weight_quantization", False)
+    self._enable_weight_quantization = True
+    try:
+        yield
+    finally:
+        self._enable_weight_quantization = prev
```
examples/llm_qat/utils.py (1)
172-175
: Perplexity value can be overwritten by existing key; compute-first merge reverses precedence
With `{"perplexity": ..., **metrics}`, any preexisting `metrics["perplexity"]` overwrites the computed value. Flip the precedence (or assign in-place) and add a small guard.
```diff
-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics."""
+    loss = metrics.get("eval_loss", None)
+    if loss is not None:
+        # math.exp avoids unnecessary torch tensor creation on CPU
+        import math
+        metrics["perplexity"] = float(math.exp(float(loss)))
+    return metrics
```
tests/_test_utils/examples/run_command.py (1)
35-44
: Also set a sane default MASTER_ADDR when injecting MASTER_PORT
Some launchers expect MASTER_ADDR; defaulting it to localhost avoids env-dependent failures. Keeping allocation/race risks low is fine for tests.
```diff
 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()
     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
```
examples/llm_qat/main.py (1)
266-269
: Guard eval-only printing against missing eval_loss
If `trainer.evaluate()` doesn't return `eval_loss`, `get_metrics_with_perplexity` will no-op after the fix; keeping the current order is fine. Consider asserting presence during tests.
Would you like a small unit test to assert the presence/shape of evaluation metrics across backends?
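A minimal sketch of such a test, assuming the helper behaves like the guarded version suggested above; the local `get_metrics_with_perplexity` stub is a stand-in for the one in `examples/llm_qat/utils.py`, not the project's actual import.
```python
import math

import pytest


def get_metrics_with_perplexity(metrics):
    """Stand-in mirroring the guarded helper discussed above."""
    loss = metrics.get("eval_loss")
    if loss is None:
        return metrics
    return {**metrics, "perplexity": float(math.exp(float(loss)))}


@pytest.mark.parametrize("metrics", [{"eval_loss": 1.25}, {"eval_runtime": 3.4}])
def test_eval_metrics_have_expected_shape(metrics):
    out = get_metrics_with_perplexity(metrics)
    if "eval_loss" in metrics:
        assert out["perplexity"] == pytest.approx(math.exp(metrics["eval_loss"]))
    else:
        # No eval_loss -> helper should be a no-op rather than raising
        assert "perplexity" not in out
```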
modelopt/torch/utils/network.py (2)
440-454
: Use isinstance for unwrapping to handle subclasses
`type(model) in SUPPORTED_WRAPPERS` misses subclasses. Iterate with `isinstance` for robustness.
```diff
-    if force_unwrap:
-        try:
-            if type(model) in SUPPORTED_WRAPPERS:
-                return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+    if force_unwrap:
+        try:
+            for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+                if isinstance(model, wrapper_type):
+                    return getattr(model, attr)
         except AttributeError:
             raise ValueError(
                 f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
                 " unwrap the model before passing it in."
             )
-    if type(model) in SUPPORTED_WRAPPERS:
+    for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+        if isinstance(model, wrapper_type):
             if raise_error:
                 raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!")
             elif warn:
                 warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...")
-        return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+            return getattr(model, attr)
     return model
```
599-612
: Also strip DataParallel's 'module.' prefix in get_unwrapped_name
DP inserts the same prefix as DDP; include it in the check.
```diff
-    if isinstance(model, nn.parallel.DistributedDataParallel) or (
+    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or (
         DeepSpeedEngine is not None and isinstance(model, DeepSpeedEngine)
     ):
         name = name.removeprefix("module.")
```
modelopt/torch/opt/plugins/peft.py (1)
84-95
: Avoid KeyError and unify with utils: use set_quantizer_state_dict.
Direct indexing `quantizer_state_dict[get_unwrapped_name(name)]` can KeyError on naming mismatches. Prefer the helper, which tolerates missing keys.
Apply this diff:
```diff
-        if os.path.isfile(_get_quantizer_state_save_path(model_id)):
-            from modelopt.torch.quantization.nn import TensorQuantizer
-
-            quantizer_state_dict = torch.load(
-                _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
-            )
-            for name, module in self.named_modules():
-                if isinstance(module, TensorQuantizer):
-                    module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
+        if os.path.isfile(_get_quantizer_state_save_path(model_id)):
+            from modelopt.torch.quantization.utils import set_quantizer_state_dict
+            quantizer_state_dict = torch.load(
+                _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
+            )
+            set_quantizer_state_dict(self, quantizer_state_dict)
```
examples/llm_qat/simple_qat_train.py (1)
124-125
: Resolve quant config from mtq.config to avoid AttributeError.
Safer to fetch from `mtq.config` (choices originate there).
Apply this diff:
```diff
-    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
+    cfg = getattr(mtq.config, args.quant_cfg)
+    model = mtq.quantize(model, cfg, calibrate)
```
tests/examples/llm_qat/test_llm_qat.py (1)
39-45
: Parametrization includes deepspeed — gate if DS isn't installed.
Consider conditionally skipping the deepspeed case when `import deepspeed` fails, to avoid infra-dependent failures in CI.
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
134-143
: Quant config resolution: prefer mtq.config with fallback.
Some configs live under `mtq.config`. Use that first, then fall back to root for re-exports.
Apply this diff:
```diff
-        if quant_args is not None and getattr(quant_args, "quant_cfg", None):
-            quant_cfg = (
-                getattr(mtq, quant_args.quant_cfg)
-                if isinstance(quant_args.quant_cfg, str)
-                else quant_args.quant_cfg
-            )
+        if quant_args is not None and getattr(quant_args, "quant_cfg", None):
+            if isinstance(quant_args.quant_cfg, str):
+                quant_cfg = getattr(getattr(mtq, "config", mtq), quant_args.quant_cfg)
+            else:
+                quant_cfg = quant_args.quant_cfg
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
- examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
- examples/llm_qat/launch.sh (4 hunks)
- examples/llm_qat/main.py (2 hunks)
- examples/llm_qat/simple_qat_train.py (2 hunks)
- examples/llm_qat/utils.py (1 hunks)
- modelopt/torch/opt/conversion.py (1 hunks)
- modelopt/torch/opt/plugins/peft.py (1 hunks)
- modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
- modelopt/torch/quantization/utils.py (2 hunks)
- modelopt/torch/utils/network.py (3 hunks)
- tests/_test_utils/examples/run_command.py (1 hunks)
- tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
- examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧬 Code graph analysis (6)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
unwrap_model
(430-454)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
get_metrics_with_perplexity
(172-175)modelopt/torch/quantization/plugins/transformers_trainer.py (1)
evaluate
(240-247)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
TensorQuantizer
(62-1182)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)
modelopt/torch/opt/plugins/peft.py (1)
modelopt/torch/quantization/utils.py (1)
get_quantizer_state_dict
(446-456)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
tiny_llama_path
(33-41)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (4)
calibrate_with_adapters
(275-286)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-465)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (3)
modelopt_state
(444-486)save
(489-511)restore_from_modelopt_state
(514-576)modelopt/torch/quantization/model_quant.py (2)
forward_loop
(95-96)quantize
(132-227)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (10)
examples/llm_qat/main.py (2)
41-41
: Import relocation LGTM
Switching `get_metrics_with_perplexity` to the local `utils` keeps training code decoupled from trainer internals.
263-263
: Nice: rank-0 "Training completed." log
Good UX for multi-rank runs.
modelopt/torch/utils/network.py (2)
73-77
: Wrapper table extension LGTM
Adding FSDP support in `SUPPORTED_WRAPPERS` is correct and unblocks unwrap/state-dict flows across backends.
79-86
: Conditional DeepSpeed support LGTM
Dynamic import and map extension are appropriate; no hard dependency.
modelopt/torch/opt/conversion.py (2)
383-384
: Force-unwrapping semantics changed — verify wrapper coverage and failure mode.
Switching to `unwrap_model(..., force_unwrap=True)` skips warnings/errors and may raise if a supported wrapper's expected attribute is absent. Please verify SUPPORTED_WRAPPERS reliably covers DDP, FSDP/FSDP2, and DeepSpeed in your test matrix to avoid surprise ValueErrors when attributes drift.
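A rough, CPU-only sketch of such a coverage check for the DDP case (FSDP/DeepSpeed coverage would need a real distributed setup); the `unwrap_model` import follows the path used elsewhere in this PR, and the port/env handling is illustrative.
```python
import os

import torch.distributed as dist
import torch.nn as nn

from modelopt.torch.utils import unwrap_model  # import path as used in this PR


def test_force_unwrap_ddp_returns_inner_module():
    # Single-process gloo group so DDP can be constructed on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        inner = nn.Linear(4, 4)
        wrapped = nn.parallel.DistributedDataParallel(inner)
        # force_unwrap should hand back the original module, not the DDP wrapper.
        assert unwrap_model(wrapped, force_unwrap=True) is inner
    finally:
        dist.destroy_process_group()
```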
469-471
: LGTM: explicit force-unwrap for state capture.
Mirrors the apply path and helps avoid wrapper leakage into saved state.
modelopt/torch/opt/plugins/peft.py (1)
60-63
: Centralized quantizer snapshot is the right call.
Using `get_quantizer_state_dict(self)` removes duplication and avoids `model.state_dict()` pitfalls under FSDP.
examples/llm_qat/simple_qat_train.py (1)
90-93
: Confirm config symbol source.
Defaulting to string is fine; ensure `NVFP4_DEFAULT_CFG` (and other choices) are exported at `mtq.config` or re-exported at `mtq`. If not, the resolution below will matter.
tests/examples/llm_qat/test_llm_qat.py (1)
36-37
: Good call setting up a free port.
Reduces DDP/Accelerate flakiness in CI.
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
261-293
: Accelerate FSDP2 patch looks sound.
Hiding quantizer buffers during prepare avoids FSDP2's "all buffers must be sharded" assumption. Good containment and restoration.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/llm_qat/launch.sh (3)
63-66
: Guard divide-by-zero and clamp save_steps to ≥1.
`torch.cuda.device_count()` can be 0 (or Python/Torch may be unavailable), causing division by zero. Also, when GPU_COUNT > 192, integer division yields 0.
```diff
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python - <<'PY'
+try:
+    import torch
+    print(torch.cuda.device_count())
+except Exception:
+    print(0)
+PY
+)
+# Fallbacks and clamps
+[[ "$GPU_COUNT" =~ ^[0-9]+$ ]] || GPU_COUNT=1
+(( GPU_COUNT > 0 )) || GPU_COUNT=1
+# Calculate save_steps per GPU, ensure at least 1
+DEFAULT_SAVE_STEPS=$(( 192 / GPU_COUNT ))
+(( DEFAULT_SAVE_STEPS > 0 )) || DEFAULT_SAVE_STEPS=1
```
88-92
: Quote-safe and unset-safe QUANT_CFG check.
Unquoted `-z $QUANT_CFG` can error or misbehave when unset or with spaces/hyphens.
```diff
-if [ -z $QUANT_CFG ]; then
-    QUANT_ARGS=""
-else
-    QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
-fi
+if [ -n "${QUANT_CFG:-}" ]; then
+    QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
+else
+    QUANT_ARGS=""
+fi
```
95-97
: Unset/empty-safe MAX_STEPS handling.
Same quoting problem here.
```diff
-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
     OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
```
🧹 Nitpick comments (4)
examples/llm_qat/launch.sh (4)
21-26
: Add a portable to_lower helper (avoid Bash 4+ dependency).
The script uses `${var,,}` later, which breaks on macOS' Bash 3.2. Add a tiny helper here and use it where you need lowercase comparisons.
Add this function just below `parse_value`:
```bash
to_lower() { printf '%s' "${1:-}" | tr '[:upper:]' '[:lower:]'; }
```
52-55
: Fix invalid-argument error message.
Currently prints only the substring after "=", which is confusing. Print the full flag.
```diff
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
+    >&2 printf "Error: Invalid argument '%s'\n" "$1"
```
99-108
: Lowercasing requires Bash 4+; use helper and warn when overriding backend due to --compress.
`${var,,}` breaks on macOS Bash. Also, silently forcing DDP when `--compress True` can surprise users—emit an info message.
Apply both changes (uses the `to_lower` helper suggested above):
```diff
-if [[ "${USE_FSDP2,,}" == "true" ]]; then
+if [[ "$(to_lower "${USE_FSDP2:-}")" == "true" ]]; then
     echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
     BACKEND="fsdp2"
 fi

-# if compress is true, set backend to ddp
-if [[ "${COMPRESS,,}" == "true" ]]; then
-    BACKEND="ddp"
+# if compress is true, set backend to ddp
+if [[ "$(to_lower "${COMPRESS:-}")" == "true" ]]; then
+    if [[ "$(to_lower "${BACKEND:-}")" != "ddp" ]]; then
+        echo "Info: --compress True forces --backend=ddp (overriding '$BACKEND')."
+    fi
+    BACKEND="ddp"
 fi
```
139-142
: Distillation backend check: avoid `${var,,}`.
Use the helper for portability.
```diff
-    if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+    if [[ "$(to_lower "${BACKEND:-}")" == "fsdp1" || "$(to_lower "${BACKEND:-}")" == "fsdp2" ]]; then
         FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
     fi
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- examples/llm_qat/launch.sh (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-docs
🔇 Additional comments (5)
examples/llm_qat/launch.sh (5)
21-25
: parse_value helper looks good.
Compact and correct handling of both `--flag=value` and `--flag value`.
175-176
: Confirm hard-coded gradient checkpointing.
It's always enabled. Please confirm it's intended across all backends/configs.
86-86
: Default BACKEND is reasonable.
Defaulting to fsdp1 keeps current behavior while allowing overrides.
145-177
: Use an args array for the accelerate command; verify deepspeed gradient_checkpointing
- Replace the CMD string in examples/llm_qat/launch.sh (lines 145–177) with an args array and invoke "${args[@]}" to avoid shell-quoting and path/spacing pitfalls.
- Confirm examples/llm_qat/accelerate_config/deepspeed.yaml does not set gradient_checkpointing that conflicts with --gradient_checkpointing True — the automated search returned no match; verify manually.
111-133
: Use helper to lowercase BACKEND and add `fsdp` alias to error message.
File: examples/llm_qat/launch.sh, lines 111-133 — replace bash-only `${BACKEND,,}` with the repo `to_lower` helper for portability; verified accelerate configs present.
```diff
-case "${BACKEND,,}" in
+case "$(to_lower "${BACKEND:-}")" in
     "fsdp1"|"fsdp")
         CONFIG_FILE="fsdp1.yaml"
         FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
         ;;
     "fsdp2")
         echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
         CONFIG_FILE="fsdp2.yaml"
         FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
         ;;
     "ddp")
         CONFIG_FILE="ddp.yaml"
         FSDP_ARGS=""
         ;;
     "deepspeed")
         CONFIG_FILE="deepspeed.yaml"
         FSDP_ARGS=""
         ;;
     *)
-        echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
+        echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp (alias fsdp1), fsdp2, ddp, deepspeed"
         exit 1
         ;;
esac
```
Ensure `to_lower` is defined in the repo (or keep `${BACKEND,,}` only if the script intentionally requires bash ≥4).
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
103-115
: Fix: handle QuantizeConfig vs dict in check_awq_smoothquant()
quant_cfg may be a QuantizeConfig instance (not dict); calling .get() will raise AttributeError. Support both types.
Apply this diff:
```diff
-def check_awq_smoothquant(quant_cfg):
+def check_awq_smoothquant(quant_cfg):  # TODO: Remove this once deepspeed for AWQ and SmoothQuant is added
     """Get the quantization type from the configuration."""
     if quant_cfg is None:
         return False
-    algorithm = quant_cfg.get("algorithm", {})
+    # Accept dict-like or QuantizeConfig
+    if isinstance(quant_cfg, dict):
+        algorithm = quant_cfg.get("algorithm", {})
+    else:
+        # QuantizeConfig or object with attribute
+        algorithm = getattr(quant_cfg, "algorithm", {}) or {}
     is_awq_smoothquant = False
     # Check SmoothQuant and AWQ
     if algorithm and ("smoothquant" in algorithm or "awq" in algorithm):
         is_awq_smoothquant = True

     return is_awq_smoothquant
```
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
197-201
: Bug: dataset can be None → TypeError on len(dataset)
If both train_dataset and eval_dataset are None, this crashes. Provide a clear error and compute length after selection.
Apply this diff:
```diff
-        dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset
-        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [union-attr]
+        dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
+        assert dataset is not None, "Calibration requires either eval or train dataset."
+        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [arg-type]
         dataset = torch.utils.data.Subset(dataset, list(range(num_samples)))
         data_loader = self.get_eval_dataloader(dataset)
```
🧹 Nitpick comments (8)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)
7-7
: Activation checkpointing enabled: verify perf/compat across backends and toolchain versions.
Turning this on reduces memory but increases recompute. Please:
- Confirm torch/accelerate versions in CI support `fsdp_activation_checkpointing` with `fsdp_use_orig_params: true` and FSDP v1.
- Watch test/runtime timeouts; QAT runs may slow notably.
- Sanity-check numerics and resume-from-checkpoint with this setting on.
Optionally gate via a CLI flag in launch.sh so users can toggle per run.
modelopt/torch/quantization/plugins/transformers_trainer.py (7)
163-169
: Idempotency: avoid double-patching accelerate.prepare
If init is called multiple times or subclasses re-invoke the patch, _original_prepare could be overwritten. Guard the patch.
Apply this diff:
```diff
-        self._patch_accelerate_for_fsdp2_fix()
+        # Patch once
+        if not hasattr(self.accelerator, "_original_prepare"):
+            self._patch_accelerate_for_fsdp2_fix()
```
170-189
: Distributed save ordering: add post-save barrier
You barrier before saving, but not after. Add a post-save barrier to ensure all ranks see a fully written file before proceeding.
Apply this diff:
```diff
-        if self.args.should_save:
-            torch.save(modelopt_full_state, self._modelopt_state_path)
+        if self.args.should_save:
+            torch.save(modelopt_full_state, self._modelopt_state_path)
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
```
190-194
: Robust load: add map_location to avoid device mismatches
torch.load without map_location can fail on CPU-only envs or pick the wrong GPU. Load to CPU, then let restore helpers move tensors.
Apply this diff:
```diff
-        modelopt_full_state = torch.load(self._modelopt_state_path, weights_only=False)
+        modelopt_full_state = torch.load(
+            self._modelopt_state_path, map_location="cpu", weights_only=False
+        )
```
202-208
: Clarify intent: use self.model by design; fix misleading comment/param
Per team learning, the forward pass should invoke self.model, not the unwrapped model parameter. Rename the unused parameter and update the comment to prevent future regressions.
Apply this diff:
```diff
-        def forward_loop(model):
+        def forward_loop(_unused_unwrapped_model):
             for batch in tqdm(data_loader, desc="Calibrating"):
                 batch = self._prepare_inputs(batch)
-                # Important: We should forward pass using the unwrapped model
-                # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
-                self.model(**batch)
+                # Important: Forward with self.model to retain wrapper hooks (DDP/FSDP2/DeepSpeed).
+                # Do not use the unwrapped model parameter passed by mtq.quantize().
+                self.model(**batch)
```
Note: Using the retrieved learning for this file; keeping self.model is intentional.
203-203
: Reduce multi-rank tqdm spam during calibration
Disable progress bars on non-main processes.
Apply this diff:
```diff
-            for batch in tqdm(data_loader, desc="Calibrating"):
+            for batch in tqdm(
+                data_loader,
+                desc="Calibrating",
+                disable=not self.accelerator.is_local_main_process,
+            ):
```
319-321
: Device placement: prefer accelerator.device over .cuda()
Calling .cuda() can select the wrong device under multi-GPU/Accelerate. Use to(self.accelerator.device).
Apply this diff:
```diff
-        self.model.cuda()
+        self.model.to(self.accelerator.device)
```
352-354
: Avoid stringly-typed state_dict_type checks
Relying on a "SHARDED_STATE_DICT" substring is brittle. Prefer comparing against the enum/type provided by the FSDP plugin if available (e.g., FSDPStateDictType.SHARDED_STATE_DICT).
If importing the enum is not feasible here, at least gate on an attribute, e.g. hasattr(self.accelerator.state.fsdp_plugin, "state_dict_type") and compare to known constants from that module.
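A rough sketch of the enum-based check, assuming Accelerate's FSDP plugin exposes `state_dict_type` as (or convertible to) torch's `StateDictType`; verify against the Accelerate version pinned in CI before relying on it.
```python
from torch.distributed.fsdp import StateDictType


def is_sharded_state_dict(accelerator) -> bool:
    """True when the FSDP plugin is configured for sharded state dicts."""
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is None or not hasattr(fsdp_plugin, "state_dict_type"):
        return False
    state_dict_type = fsdp_plugin.state_dict_type
    # Prefer the enum comparison; fall back to a name check for string-typed configs.
    if isinstance(state_dict_type, StateDictType):
        return state_dict_type == StateDictType.SHARDED_STATE_DICT
    return str(state_dict_type).endswith("SHARDED_STATE_DICT")
```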
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
- examples/llm_qat/launch.sh (4 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_qat/launch.sh
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
261-293
: Quantizer buffers: verify to_empty()/restore sequence is safe
tq.to_empty() frees buffer storage; after prepare you only restore _non_persistent_buffers_set, not buffer contents/devices. Confirm that later calibration/restore repopulates buffers correctly for all TensorQuantizer modules under FSDP2.
Would you run a quick check across backends (FSDP2, DDP) that prints dtype/device and shapes of a few representative quantizer buffers before and after prepare, and post‑calibration, to ensure they’re valid?
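A minimal sketch of that check, assuming `TensorQuantizer` is importable from `modelopt.torch.quantization.nn` as in the diffs above; call it before and after `accelerator.prepare()` and again after calibration on each backend.
```python
import torch.nn as nn

from modelopt.torch.quantization.nn import TensorQuantizer


def dump_quantizer_buffers(model: nn.Module, tag: str, limit: int = 5) -> None:
    """Print dtype/device/shape for a few TensorQuantizer buffers to spot empty or misplaced storage."""
    printed = 0
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        for buf_name, buf in module.named_buffers(recurse=False):
            print(f"[{tag}] {name}.{buf_name}: dtype={buf.dtype}, device={buf.device}, shape={tuple(buf.shape)}")
            printed += 1
            if printed >= limit:
                return
```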
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)
12-17
: Fix mis-keyed FSDP option: FULL_SHARD belongs to sharding strategy, not reshard_after_forward.
`fsdp_reshard_after_forward` expects a boolean. `FULL_SHARD` should be set on `fsdp_sharding_strategy`. This mis-key will be ignored or cause unexpected behavior.
Apply this diff:
```diff
 fsdp_forward_prefetch: false
-fsdp_offload_params: false
-fsdp_reshard_after_forward: FULL_SHARD
+fsdp_offload_params: false
+fsdp_sharding_strategy: FULL_SHARD
+fsdp_reshard_after_forward: true
 fsdp_state_dict_type: FULL_STATE_DICT
```
examples/llm_qat/launch.sh (1)
63-66
: Bug: division by zero when no GPUs.
GPU_COUNT can be 0 on CPU-only runners; DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) will fail.
```diff
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python - <<'PY'
+import os
+try:
+    import torch
+    n = torch.cuda.device_count()
+except Exception:
+    n = 0
+print(max(1, int(n)))
+PY
+)
+# Calculate save_steps (fallback to 192 when CPU-only)
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
```
♻️ Duplicate comments (5)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)
23-23
: `num_processes: gpu` — OK per earlier context; verify CI/runtime Accelerate version.
Per your note, this passes through to torchrun. Please ensure the deployed Accelerate version supports this token across all backends to avoid surprises on older runners.
examples/llm_qat/accelerate_config/deepspeed.yaml (1)
17-17
: `num_processes: gpu` — acknowledged as intentional.
Keeping as-is per earlier discussion; just ensure runners use the compatible Accelerate version.
modelopt/torch/quantization/utils.py (1)
28-29
: Right call: use get_unwrapped_name(name, model) to stabilize keys across wrappers.
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
202-208
: Forward pass via self.model is intentional — approved.
Per your note, use self.model(**batch) instead of the unwrapped parameter to respect Trainer hooks.
197-201
: Fix None dataset path in calibration.
If both train and eval datasets are None, len(dataset) raises. Select dataset first and assert it exists.
```diff
-        dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset
-        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [union-attr]
+        dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
+        assert dataset is not None, "Calibration requires either eval or train dataset."
+        num_samples = min(self.quant_args.calib_size, len(dataset))
```
🧹 Nitpick comments (18)
modelopt/torch/quantization/calib/histogram.py (1)
160-163
: Broaden warning to all distributed wrappers (not just DDP).
Message still ends with "DDP modules," which re-introduces wrapper specificity. Recommend generic phrasing.
```diff
-            " method is to use the same calibration dataset across all distributed data"
-            " parallel groups so that `amax` is the same for all DDP modules."
+            " method is to use the same calibration dataset across all distributed data"
+            " parallel groups so that `amax` is the same across all ranks/modules."
```
modelopt/torch/quantization/nn/modules/quant_module.py (1)
161-165
: Good fix: ensures cleanup even on exceptions.
The try/finally wrapper guarantees `_enable_weight_quantization` is reset. Consider a nested-safe counter (increment/decrement) if nested contexts are possible, otherwise this is fine.
examples/llm_qat/utils.py (1)
172-175
: Avoid tensor creation; guard missing/overflow for perplexity.
Use math.exp on the float, handle missing eval_loss, and prevent overflow to keep this helper robust and cheap.
```diff
-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+from math import exp
+
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity from eval_loss if present; fall back gracefully."""
+    loss = metrics.get("eval_loss")
+    if loss is None:
+        return metrics
+    try:
+        ppl = float(exp(float(loss)))
+    except OverflowError:
+        ppl = float("inf")
+    return {**metrics, "perplexity": ppl}
```
tests/_test_utils/examples/run_command.py (1)
35-44
: Set MASTER_ADDR alongside MASTER_PORT for deterministic local runs.
Helps avoid cross-env surprises when a global address is pre-set or missing.
```diff
 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()
     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
```
modelopt/torch/utils/network.py (1)
439-447
: Improve error context when force_unwrap fails.
Include the expected attribute name in the error for faster diagnosis.
```diff
-    except AttributeError:
+    except AttributeError:
         raise ValueError(
-            f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
-            " unwrap the model before passing it in."
+            f"Model of type {type(model)} could not be forcefully unwrapped "
+            f"(missing attr '{SUPPORTED_WRAPPERS.get(type(model), 'module')}'). "
+            "Please unwrap the model before passing it in."
         )
```
modelopt/torch/quantization/conversion.py (1)
115-123
: Avoid recomputingquantizer_state(model)
twice.Store it once to reduce traversal work on large models.
- quantizer_state_dict = metadata["quantizer_state"] - unmatched_keys = quantizer_state_dict.keys() - quantizer_state(model).keys() - extra_keys = quantizer_state(model).keys() - quantizer_state_dict.keys() + quantizer_state_dict = metadata["quantizer_state"] + current_keys = quantizer_state(model).keys() + unmatched_keys = quantizer_state_dict.keys() - current_keys + extra_keys = current_keys - quantizer_state_dict.keys()examples/llm_qat/main.py (2)
189-191
: Clean stray characters in comment.
There's a trailing "åå".
```diff
-    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.åå
+    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
```
266-269
: Guard perplexity computation when eval_loss is absent.
Prevents KeyError on custom metrics.
```diff
-    metrics = trainer.evaluate()
-    metrics = get_metrics_with_perplexity(metrics)
-    print_rank_0(f"Evaluation results: \n{metrics}")
+    metrics = trainer.evaluate()
+    if "eval_loss" in metrics:
+        metrics = get_metrics_with_perplexity(metrics)
+    print_rank_0(f"Evaluation results: \n{metrics}")
```
modelopt/torch/opt/plugins/peft.py (1)
85-95
: Pre-validate saved vs. current quantizer keys for clearer errors.
Direct indexing can raise a bare KeyError on mismatches. Add a short key-set check (mirrors
conversion.restore_quantizer_state
) for better diagnostics.
```diff
         if os.path.isfile(_get_quantizer_state_save_path(model_id)):
             from modelopt.torch.quantization.nn import TensorQuantizer

             quantizer_state_dict = torch.load(
                 _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
             )
+            # Validate keys before loading for clearer messaging
+            expected = {
+                get_unwrapped_name(n, self)
+                for n, m in self.named_modules()
+                if isinstance(m, TensorQuantizer)
+            }
+            got = set(quantizer_state_dict.keys())
+            missing = expected - got
+            extra = got - expected
+            if missing or extra:
+                raise ValueError(f"Quantizer state key mismatch. missing={missing}, extra={extra}")
             for name, module in self.named_modules():
                 if isinstance(module, TensorQuantizer):
-                    module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)])
+                    key = get_unwrapped_name(name, self)
+                    module.load_state_dict(quantizer_state_dict[key])
```
modelopt/torch/opt/conversion.py (1)
590-592
: Harden restore semantics: reject wrapped models consistently.
restore() now asserts unwrapped input; restore_from_modelopt_state() does not. To avoid accidental misuse via direct calls, mirror the same check there.
```diff
 def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]) -> nn.Module:
     """Restore the model architecture from the modelopt state dictionary based on the user-provided model."""
     # initialize ModelLikeModule if needed.
-    model = model if isinstance(model, nn.Module) else ModelLikeModule(model)
+    model = model if isinstance(model, nn.Module) else ModelLikeModule(model)
+    # Keep behavior consistent with `restore()`: do not allow wrapped models here either.
+    from modelopt.torch.utils import unwrap_model
+    model = unwrap_model(model, raise_error=True)
```
tests/examples/llm_qat/test_llm_qat.py (2)
39-45
: Param sweep over backends is great; add env-conditional skips for optional backends.
CI often lacks DeepSpeed (and sometimes FSDP2). Guard the tests to skip when the backend isn't available instead of failing.
```diff
 @pytest.mark.parametrize("backend", [
     "fsdp1",
     "fsdp2",
     "deepspeed",
     "ddp",
 ])
 def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path, backend):
+    if backend == "deepspeed":
+        pytest.importorskip("deepspeed")
+    if backend == "fsdp2":
+        torch = pytest.importorskip("torch")
+        from packaging.version import Version
+        if Version(torch.__version__) < Version("2.3"):
+            pytest.skip("FSDP2 requires torch>=2.3")
```
Repeat the small guard in test_llama_qat_int4w_int8a_direct_qat.
76-86
: Direct QAT test: consider marking slow to keep CI wall-time sane.
Mark as slow if your CI budget is tight.
```diff
-@pytest.mark.parametrize("backend", [
+@pytest.mark.slow
+@pytest.mark.parametrize("backend", [
```
modelopt/torch/quantization/utils.py (1)
459-466
: Consider non-strict load for forward/backward compat.
Older/newer TensorQuantizer shapes/keys may drift; strict=False reduces fragility while _load_from_state_dict still handles specifics.
```diff
-                    module.load_state_dict(quantizer_state_dict[key])
+                    module.load_state_dict(quantizer_state_dict[key], strict=False)
```
If you expect strict matching across versions, ignore this.
examples/llm_qat/launch.sh (2)
88-93
: Quote QUANT_CFG in test to avoid word-splitting.
Minor robustness fix.
```diff
-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG}" ]; then
     QUANT_ARGS=""
 else
     QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi
```
53-55
: Error message prints the value, not the invalid flag.
Use `$1` instead of `${1#*=}` for clarity.
```diff
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
+    >&2 printf "Error: Invalid argument %s\n" "$1"
```
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
170-186
: Potentially ineffective filter for modelopt_state_dict.
state is a (mode_str, state_dict) tuple; checking "kd_loss" in state won't inspect metadata. The filter likely keeps everything.
```diff
-        modelopt_state["modelopt_state_dict"] = [
-            state
-            for state in modelopt_state["modelopt_state_dict"]
-            if "kd_loss" not in state and "export_student" not in state
-        ]
+        filtered = []
+        for m_str, m_state in modelopt_state["modelopt_state_dict"]:
+            meta = m_state.get("metadata", {})
+            if m_str in {"distill", "kd_loss"} or meta.get("export_student", False):
+                continue
+            filtered.append((m_str, m_state))
+        modelopt_state["modelopt_state_dict"] = filtered
```
Adjust the mode names if your registry uses different identifiers.
240-247
: FSDP2 eval-only hack looks correct; tiny nit: avoid creating opt on empty params.
If the model has no parameters (rare), next(self.model.parameters()) raises. Optional safeguard:
```diff
-        dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+        first_param = next(self.model.parameters(), None)
+        dummy_optimizer = torch.optim.SGD([first_param] if first_param is not None else [torch.nn.Parameter(torch.zeros(1, device=self.accelerator.device))], lr=0.0)
```
261-293
: Accelerate FSDP2 patch is clever; leave breadcrumbs for future updates.
You're touching private attrs (_non_persistent_buffers_set); add a one-liner warning so future Accelerate upgrades are audited.
```diff
 def _modelopt_prepare(self, *args, **kwargs):
+    # NOTE: Relies on private Accelerate/FSDP internals; re-verify if Accelerate/FSDP2 is updated.
     if not self.is_fsdp2:
         return self._original_prepare(*args, **kwargs)
```
Please confirm the targeted Accelerate version(s) in CI where this patch is validated.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
- examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
- examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
- examples/llm_qat/launch.sh (4 hunks)
- examples/llm_qat/main.py (2 hunks)
- examples/llm_qat/simple_qat_train.py (2 hunks)
- examples/llm_qat/utils.py (1 hunks)
- modelopt/torch/opt/conversion.py (2 hunks)
- modelopt/torch/opt/dynamic.py (1 hunks)
- modelopt/torch/opt/plugins/peft.py (2 hunks)
- modelopt/torch/quantization/calib/histogram.py (1 hunks)
- modelopt/torch/quantization/conversion.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
- modelopt/torch/quantization/utils.py (2 hunks)
- modelopt/torch/utils/network.py (3 hunks)
- tests/_test_utils/examples/run_command.py (1 hunks)
- tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
- examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T16:36:42.871Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: examples/llm_qat/accelerate_config/deepspeed.yaml:17-17
Timestamp: 2025-09-15T16:36:42.871Z
Learning: In Hugging Face Accelerate configuration YAML files, num_processes can accept the string "gpu" as a value, which gets passed through to torch run under the hood and functions correctly, despite standard documentation showing integer values.
Applied to files:
examples/llm_qat/accelerate_config/deepspeed.yaml
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (8)
modelopt/torch/quantization/conversion.py (3)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
set_from_modelopt_state
(1122-1140)get_modelopt_state
(1105-1120)get_modelopt_state
(1246-1248)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
get_quantizer_state_dict
(446-456)modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)
modelopt/torch/opt/dynamic.py (1)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
get_metrics_with_perplexity
(172-175)modelopt/torch/quantization/plugins/transformers_trainer.py (1)
evaluate
(240-247)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
TensorQuantizer
(62-1182)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
unwrap_model
(430-454)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
tiny_llama_path
(33-41)
modelopt/torch/quantization/plugins/transformers_trainer.py (6)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (5)
calibrate_with_adapters
(275-286)disable_lora_quantizers_in_config
(289-296)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (3)
modelopt_state
(444-486)save
(489-507)restore_from_modelopt_state
(510-567)modelopt/torch/quantization/model_quant.py (3)
forward_loop
(95-96)quantize
(132-227)print_quant_summary
(463-470)modelopt/torch/distill/plugins/huggingface.py (1)
save_model
(48-92)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (13)
modelopt/torch/utils/network.py (1)
73-86
: Wrapper matrix looks good; DeepSpeed inclusion is correct.
Nice to see deepspeed handled conditionally.
If you rely on is_parallel() for routing, consider adding FSDP/DeepSpeed there in a follow-up.
examples/llm_qat/simple_qat_train.py (1)
90-92
: Harden runtime config lookup for --quant-cfg
Guard the getattr and raise a clear ValueError listing valid choices; the provided verification script produced no output, so I could not confirm that every entry in mtq.config.choices is defined on modelopt.torch.quantization — re-run the check or verify manually.
Location: examples/llm_qat/simple_qat_train.py lines 90–92 (also 124–125)
```diff
-    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
+    cfg = getattr(mtq, args.quant_cfg, None)
+    if cfg is None:
+        raise ValueError(
+            f"Unknown quant config '{args.quant_cfg}'. Valid choices: {mtq.config.choices}"
+        )
+    model = mtq.quantize(model, cfg, calibrate)
```
examples/llm_qat/accelerate_config/fsdp1.yaml (1)
7-7
: Enabling activation checkpointing: looks good.
Change is appropriate for memory savings under FSDP. No further action.
modelopt/torch/quantization/conversion.py (2)
124-133
: Wrapper-aware naming in restore path — correct.Passing
model
toget_unwrapped_name
for bothTensorQuantizer
andQuantModule
matches the updated API and fixes wrapper-prefix mismatches (DDP/FSDP/DeepSpeed).
169-172
: Consistent keying for saved metadata — good.Using
get_unwrapped_name(n, model)
ensures saved metadata lines up with restore under wrappers.examples/llm_qat/accelerate_config/deepspeed.yaml (1)
1-24
: DeepSpeed Accelerate config — sensible defaults.Zero-3 with no offload and bf16 is consistent with the training flow.
Please confirm DS/Accelerate versions in CI match these fields (zero3_init_flag, zero3_save_16bit_model, offload_*_device: none) to avoid schema drift across versions.
modelopt/torch/opt/plugins/peft.py (1)
59-65
: DRY improvement: centralize quantizer state collection — good.Replacing ad-hoc collection with
get_quantizer_state_dict(self)
reduces duplication and handles FSDP safely.modelopt/torch/opt/conversion.py (1)
383-384
: Early force‑unwrap in apply_mode looks correct.Using force_unwrap=True here prevents nested wrapper state from leaking into ModelOpt state. No issues.
Please confirm DeepSpeedEngine is included in SUPPORTED_WRAPPERS with the correct attribute (usually "module") so force_unwrap doesn’t throw for DS.
tests/examples/llm_qat/test_llm_qat.py (1)
36-37
: Nice: free port setup prevents DDP/FS port clashes in CI.*modelopt/torch/quantization/utils.py (1)
446-456
: Quantizer state export avoids FSDP hangs — good.Implementation is lean and avoids model.state_dict(); keying via get_unwrapped_name is correct.
examples/llm_qat/launch.sh (1)
99-136
: Backend routing LGTM; compress→ddp fallback is a sensible default.Please ensure accelerate_config/{fsdp1,fsdp2,ddp,deepspeed}.yaml ship in wheels/sdists so tests don’t fail when installed from package.
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
172-175
: Barrier placement is fine; just ensure every rank calls this method.Since _save_modelopt_state_with_weights() is invoked on all ranks, the barrier won’t deadlock. Keep it that way.
249-259
: Good: switch to FULL_STATE_DICT for final save when not in train.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/utils/network.py (1)
599-613
: DP prefix handling added (addresses prior feedback); minor robustness nit.Nice: now strips module. for DataParallel too. Consider stripping repeated prefixes to handle nested wrapping chains.
Apply this diff:
- if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or ( + if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or ( DeepSpeedEngine is not None and isinstance(model, DeepSpeedEngine) ): - name = name.removeprefix("module.") + while name.startswith("module."): + name = name[len("module."):]
🧹 Nitpick comments (2)
modelopt/torch/utils/network.py (2)
90-93
: Align is_parallel() with SUPPORTED_WRAPPERS.Currently it misses FSDP and DeepSpeed. Use the registry to avoid drift.
Apply this diff:
def is_parallel(model: nn.Module) -> bool: - """Check if a PyTorch model is parallelized.""" - return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) + """Check if a PyTorch model is wrapped by a supported parallel/distributed wrapper.""" + return any(isinstance(model, t) for t in SUPPORTED_WRAPPERS)
440-454
: Use isinstance instead of exact type match when unwrapping.Exact type checks can miss subclasses/proxies (e.g., FSDP variants).
Apply this diff:
- try: - if type(model) in SUPPORTED_WRAPPERS: - return getattr(model, SUPPORTED_WRAPPERS[type(model)]) + try: + for wrapper_t, attr in SUPPORTED_WRAPPERS.items(): + if isinstance(model, wrapper_t): + return getattr(model, attr) except AttributeError: raise ValueError( f"Model of type {type(model)} could not be forcefully unwrapped! Please manually" " unwrap the model before passing it in." ) - if type(model) in SUPPORTED_WRAPPERS: + for wrapper_t, attr in SUPPORTED_WRAPPERS.items(): + if isinstance(model, wrapper_t): if raise_error: raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!") elif warn: warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...") - return getattr(model, SUPPORTED_WRAPPERS[type(model)]) + return getattr(model, attr) return model
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/opt/dynamic.py
(1 hunks)modelopt/torch/utils/network.py
(3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/opt/dynamic.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/utils/network.py (2)
73-77
: Wrapper registry centralization and DeepSpeed dynamic add: LGTM.Good move to make SUPPORTED_WRAPPERS the single source of truth and register DeepSpeed at runtime.
Also applies to: 79-86
599-613
: Verified: call sites updated and Python requirement OK.AST scan found no single-arg/zero-arg calls to get_unwrapped_name — all occurrences pass model (modelopt/torch/quantization/utils.py:455, quantization/conversion.py:126/131/169, opt/dynamic.py:1276, opt/plugins/peft.py:93). setup.py declares python_requires=">=3.10,<3.13" (pyproject targets py310), so removeprefix/removesuffix and PEP 585 generics are supported.
Walkthrough
Adds a DeepSpeed Accelerate config, enables FSDP activation checkpointing, centralizes model-aware unwrapping and quantizer state snapshot/restore, implements a stateful FSDP2-aware quantization flow with explicit save/restore, refactors example launch/tests for backend selection and port handling, and removes the sharded-checkpoint conversion script and its automated calls.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant Launch as launch.sh
participant Accel as Accelerate
participant Trainer as QAT/QAD Trainer
participant Model as Model
participant Disk as Storage
User->>Launch: run example (--backend ...)
Launch->>Launch: parse args, set BACKEND
Launch->>Accel: select config (fsdp1/fsdp2/ddp/deepspeed)
Accel-->>Launch: environment prepared
Launch->>Trainer: start
rect rgba(220,235,255,0.35)
Note right of Trainer: FS‑DP2-specific adjustments
Trainer->>Trainer: _patch_accelerate_for_fsdp2_fix()
end
alt modelopt state exists
Trainer->>Disk: read modelopt + quantizer state
Trainer->>Trainer: _restore_modelopt_state_with_weights()
Trainer->>Model: restore weights and quantizer buffers
else no prior state
Trainer->>Model: run subset forward-loop for calibration
Trainer->>Trainer: apply quantization (mtq.quantize) and optional compress
Trainer->>Disk: _save_modelopt_state_with_weights()
end
Trainer->>User: metrics (perplexity via examples.utils.get_metrics_with_perplexity)
sequenceDiagram
autonumber
participant PEFT as PEFT Plugin
participant Utils as quantization.utils
participant Model as Model
participant Disk as Storage
PEFT->>Utils: get_quantizer_state_dict(Model)
Utils->>Model: traverse modules, collect TensorQuantizer.state_dict (keys via get_unwrapped_name(name, model))
Utils-->>PEFT: quantizer_state_dict
PEFT->>Disk: write modelopt + quantizer_state_dict
Disk->>PEFT: read modelopt + quantizer_state_dict
PEFT->>Utils: set_quantizer_state_dict(Model, quantizer_state_dict)
Utils->>Model: restore per-quantizer state by model-aware keys
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
modelopt/torch/opt/dynamic.py (1)
1281-1335
: Bug: select() mismatches keys under wrappers; normalize names before comparisons.
config()
now returns unwrapped keys, butselect()
compares against raw names fromnamed_hparams()
. Under DDP/DeepSpeed, this rejects valid configs. Normalize both sides withget_unwrapped_name(..., self.model)
.Apply this diff:
def select(self, config: dict[str, Any], strict: bool = True) -> None: @@ - configurables = dict(self.named_hparams(configurable=True)) + # Normalize keys to work under wrappers (DDP/DeepSpeed). + _unwrap = lambda n: get_unwrapped_name(n, self.model) + normalized_config = {_unwrap(k): v for k, v in config.items()} + configurables_raw = dict(self.named_hparams(configurable=True)) + configurables = {_unwrap(n): hp for n, hp in configurables_raw.items()} @@ - check_non_configurable = any( - name in config and name not in configurables for name, hp in self.named_hparams() - ) + all_hps_unwrapped = {_unwrap(n): hp for n, hp in self.named_hparams()} + check_non_configurable = any( + n in normalized_config and n not in configurables for n in all_hps_unwrapped + ) @@ - unexpected_keys = dict.fromkeys(config.keys(), True) + unexpected_keys = dict.fromkeys(normalized_config.keys(), True) @@ - for name, hparam in configurables.items(): - if name in config: - hparam.active = config[name] + for name, hparam in configurables.items(): + if name in normalized_config: + hparam.active = normalized_config[name] unexpected_keys[name] = False elif strict: missing_keys.append(name) @@ - for name, hparam in self.named_hparams(): - if name in configurables: + for name, hparam in all_hps_unwrapped.items(): + if name in configurables: continue - if name not in config: + if name not in normalized_config: missing_keys.append(name) continue unexpected_keys[name] = False - if hparam.active != config[name]: + if hparam.active != normalized_config[name]: inconsistent_keys.append( - f"{name}: active={hparam.active}, config={config[name]}" + f"{name}: active={hparam.active}, config={normalized_config[name]}" )modelopt/torch/quantization/conversion.py (1)
90-93
: Critical: convert_to_quantized_model return value ignored.This leaves
model
asModelLikeModule
and prevents quantizer insertion before restore.Apply this diff:
- convert_to_quantized_model(model, config) + model, _ = convert_to_quantized_model(model, config)examples/llm_qat/launch.sh (1)
63-66
: Fix divide-by-zero when no GPUs are visible.GPU_COUNT can be 0 in CI/CPU envs, causing DEFAULT_SAVE_STEPS=$((192 / 0)).
Apply:
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +GPU_COUNT=$(python -c "import torch; import os; print(torch.cuda.device_count())") +# Calculate save_steps (fallback to 1 when no GPUs) +DEFAULT_SAVE_STEPS=$((192 / (GPU_COUNT > 0 ? GPU_COUNT : 1)))
♻️ Duplicate comments (3)
modelopt/torch/opt/dynamic.py (1)
1275-1277
: Fix: pass nn.Module context to get_unwrapped_name — correct.Using
self.model
ensures wrapper prefixes are stripped (DDP/DeepSpeed). Matches prior guidance.modelopt/torch/quantization/utils.py (1)
446-456
: Wrapper-stable quantizer keys — LGTM.Using get_unwrapped_name(name, model) fixes cross-wrapper key drift. Matches prior guidance.
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
197-201
: Guard against missing datasets before len(...).Order can raise TypeError when both datasets are None. Recommend:
- dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset - num_samples = min(self.quant_args.calib_size, len(dataset)) # type: ignore [union-attr] + dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset + assert dataset is not None, "Calibration requires either eval or train dataset." + num_samples = min(self.quant_args.calib_size, len(dataset))
🧹 Nitpick comments (16)
modelopt/torch/quantization/nn/modules/quant_module.py (2)
157-165
: Make quantize_weight re-entrant to avoid disabling on nested contextsWith the current bool flag, nested
with self.quantize_weight()
blocks will disable weight quantization when the inner context exits. Use a depth counter so the flag only flips to False at depth 0.Apply:
@contextlib.contextmanager def quantize_weight(self): """Context in which `self.weight` is quantized.""" - self._enable_weight_quantization = True - try: - yield - finally: - self._enable_weight_quantization = False + depth = getattr(self, "_weight_quantization_depth", 0) + 1 + self._weight_quantization_depth = depth + self._enable_weight_quantization = True + try: + yield + finally: + depth -= 1 + self._weight_quantization_depth = depth + if depth == 0: + self._enable_weight_quantization = False
181-187: Register a depth counter attribute for re-entrancy. Initialize the counter alongside the existing flag.
Apply:
 def _setup(self):
     super()._setup()
     self._register_temp_attribute(
         "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
     )
     self._register_temp_attribute("_enable_weight_quantization", False)
+    self._register_temp_attribute("_weight_quantization_depth", 0)
     self._register_dynamic_attribute("weight", self._get_quantized_weight)
examples/llm_qat/utils.py (1)
172-175: Guard against missing/invalid eval_loss and avoid unnecessary tensor ops. Handle absent `eval_loss`/`loss` keys and non-finite values; keep it torch-free for speed.
Apply:
-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics if loss is present."""
+    loss = metrics.get("eval_loss", metrics.get("loss"))
+    if loss is None:
+        return metrics
+    try:
+        ppl = float(torch.exp(torch.tensor(loss)))
+    except Exception:
+        return metrics
+    return {"perplexity": ppl, **metrics}
modelopt/torch/quantization/calib/histogram.py (1)
160-163: Neutral wording: not all setups are "DDP". Since this path can run under FSDP/DeepSpeed too, avoid "DDP modules" phrasing.
Apply:
-            " method is to use the same calibration dataset across all distributed data"
-            " parallel groups so that `amax` is the same for all DDP modules."
+            " method is to use the same calibration dataset across all distributed data"
+            " parallel groups so that `amax` is consistent across all modules."
tests/_test_utils/examples/run_command.py (1)
35-44: Set MASTER_ADDR for robustness on multi-NIC hosts. Helps avoid rendezvous surprises on machines with multiple interfaces.
Apply:
 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()
     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
modelopt/torch/utils/network.py (3)
79-83: Narrow the bare except to ImportError. Avoid swallowing unrelated runtime errors during import.
Apply:
 try:
     from deepspeed.runtime.engine import DeepSpeedEngine
-except:  # noqa: E722
+except ImportError:
     DeepSpeedEngine = None
438-454: Use isinstance against SUPPORTED_WRAPPERS for unwrapping (handles subclasses too). The current `type(model) in SUPPORTED_WRAPPERS` misses subclasses and proxies.
Apply:
-    if force_unwrap:
-        try:
-            if type(model) in SUPPORTED_WRAPPERS:
-                return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+    if force_unwrap:
+        try:
+            for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+                if isinstance(model, wrapper_type):
+                    return getattr(model, attr)
         except AttributeError:
             raise ValueError(
                 f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
                 " unwrap the model before passing it in."
             )
-    if type(model) in SUPPORTED_WRAPPERS:
+    for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+        if isinstance(model, wrapper_type):
             if raise_error:
                 raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!")
             elif warn:
                 warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...")
-        return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+            return getattr(model, attr)
     return model
90-93: is_parallel should align with SUPPORTED_WRAPPERS (include FSDP/DeepSpeed when present). Leverage the canonical wrapper list to avoid drift.
Apply:
 def is_parallel(model: nn.Module) -> bool:
     """Check if a PyTorch model is parallelized."""
-    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+    return isinstance(model, tuple(SUPPORTED_WRAPPERS.keys()))
examples/llm_qat/simple_qat_train.py (2)
118-123: Calibrate without grads and restore mode. Avoid unnecessary autograd and keep the model's mode intact.
Apply this diff:
-    def calibrate(m: nn.Module):
-        for batch in calib_dataloader:
-            m(batch["input_ids"].to(device))
+    def calibrate(m: nn.Module):
+        was_training = m.training
+        m.eval()
+        with torch.no_grad():
+            for batch in calib_dataloader:
+                m(batch["input_ids"].to(device))
+        if was_training:
+            m.train()
109-117: Nit: avoid double .cuda()/.to(device). Define device before model creation and move once.
Apply this diff:
-    # Load model and initialize loss
-    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
-
-    # Get dataloaders
-    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Load model and initialize loss
+    model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+    # Get dataloaders
+    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
@@
-    model.to(device)
+    # already moved above
Also applies to: 130-132
examples/llm_qat/main.py (1)
190-191: Typo: stray characters in comment. Remove the trailing "åå".
Apply this diff:
-    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.åå
+    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
modelopt/torch/quantization/utils.py (1)
459-466: Make restore tolerant; early-exit on empty dict. Loading with strict=True can fail on minor shape/version drift; a no-op fast path also helps. Suggest:
-def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
+def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict, strict: bool = True):
     """Set the state dict of the quantizers in the model."""
     from .nn import TensorQuantizer
+    if not quantizer_state_dict:
+        return
     for name, module in model.named_modules():
         key = get_unwrapped_name(name, model)
         if isinstance(module, TensorQuantizer) and key in quantizer_state_dict:
-            module.load_state_dict(quantizer_state_dict[key])
+            module.load_state_dict(quantizer_state_dict[key], strict=strict)
examples/llm_qat/launch.sh (3)
52-55: Print the invalid option, not its value. The current message shows ${1#*=}, which is the value. Use $1.
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
+    >&2 printf "Error: Invalid argument %s\n" "$1"
88-93: Quote variables in tests to avoid word-splitting. Prevent surprises if QUANT_CFG contains spaces.
-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG:-}" ]; then
     QUANT_ARGS=""
 else
-    QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
+    QUANT_ARGS="--quant_cfg \"$QUANT_CFG\" --calib_size \"$CALIB_SIZE\""
 fi
94-99: Quote the MAX_STEPS check. Avoid unary-operator errors when unset.
-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
     OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
202-208: Unused param: clarify intent of forward_loop's model arg. Rename it to underscore to avoid confusion with self.model.
-    def forward_loop(model):
+    def forward_loop(_):
         for batch in tqdm(data_loader, desc="Calibrating"):
             batch = self._prepare_inputs(batch)
             # Important: We should forward pass using the unwrapped model
             # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
             self.model(**batch)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
- examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
- examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
- examples/llm_qat/launch.sh (4 hunks)
- examples/llm_qat/main.py (2 hunks)
- examples/llm_qat/simple_qat_train.py (2 hunks)
- examples/llm_qat/utils.py (1 hunks)
- modelopt/torch/opt/conversion.py (2 hunks)
- modelopt/torch/opt/dynamic.py (1 hunks)
- modelopt/torch/opt/plugins/peft.py (2 hunks)
- modelopt/torch/quantization/calib/histogram.py (1 hunks)
- modelopt/torch/quantization/conversion.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
- modelopt/torch/quantization/utils.py (2 hunks)
- modelopt/torch/utils/network.py (3 hunks)
- tests/_test_utils/examples/run_command.py (1 hunks)
- tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
- examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T16:36:42.871Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: examples/llm_qat/accelerate_config/deepspeed.yaml:17-17
Timestamp: 2025-09-15T16:36:42.871Z
Learning: In Hugging Face Accelerate configuration YAML files, num_processes can accept the string "gpu" as a value, which gets passed through to torch run under the hood and functions correctly, despite standard documentation showing integer values.
Applied to files:
examples/llm_qat/accelerate_config/deepspeed.yaml
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (8)
modelopt/torch/opt/dynamic.py (1)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)
modelopt/torch/quantization/conversion.py (4)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
set_from_modelopt_state
(1122-1140)get_modelopt_state
(1105-1120)get_modelopt_state
(1246-1248)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)modelopt/torch/quantization/nn/modules/quant_module.py (1)
QuantModule
(37-114)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
TensorQuantizer
(62-1182)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
tiny_llama_path
(33-41)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
get_metrics_with_perplexity
(172-175)modelopt/torch/quantization/plugins/transformers_trainer.py (1)
evaluate
(240-247)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
unwrap_model
(430-454)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
get_quantizer_state_dict
(446-456)modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (4)
calibrate_with_adapters
(275-286)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (3)
modelopt_state
(444-486)save
(489-507)restore_from_modelopt_state
(510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (19)
examples/llm_qat/accelerate_config/deepspeed.yaml (1)
1-23
: DS config looks good; num_processes: gpu accepted per project practiceAcknowledging prior learning in this repo that
"gpu"
is passed through to torch run and works as intended.If you want, I can add a quick check in CI to print the resolved launch args from Accelerate for visibility.
modelopt/torch/utils/network.py (1)
605-609
: Name cleaning logic: OKStripping "module." for DP/DDP/DS and then delegating to FSDP’s utility is correct.
tests/_test_utils/examples/run_command.py (1)
35-44
: Pass setup_free_port=True for distributed example runsAdd setup_free_port=True to run_example_command calls that launch multi-process/backends (accelerate launch, ./launch.sh, torchrun).
- tests/examples/llm_distill/test_llm_distill.py:23 (accelerate launch --multi_gpu)
- tests/examples/llm_qat/test_llm_qat.py: update _run_command (line 23) and its callers (lines 50, 61, 78, 89, 102, 103)
- tests/examples/speculative_decoding/test_eagle.py:22 (./launch.sh)
- tests/examples/speculative_decoding/test_medusa.py:33, 51, 73 (./launch.sh calls)
modelopt/torch/opt/conversion.py (2)
383-383
: Force-unwrapping in apply_mode — good call.Ensures consistent behavior across wrappers before conversion.
590-592
: Explicit unwrap on restore — correct.Rejecting wrapped models up-front improves clarity and avoids subtle failures.
examples/llm_qat/accelerate_config/fsdp1.yaml (1)
7-7
: Enabling FSDP activation checkpointing — verify interaction with training args.Ensure
TrainingArguments.gradient_checkpointing=True
(withuse_reentrant
) and this setting do not double-apply or conflict for the chosen transformers/accelerate versions in CI.examples/llm_qat/simple_qat_train.py (2)
88-93
: CLI default now a string name — OK if attributes exist on mtq.
choices=mtq.config.choices
should match attribute names. Looks fine.
124-125
: Using getattr(mtq, args.quant_cfg) — correct given string default.No issues.
modelopt/torch/quantization/conversion.py (1)
124-133
: State keying now model‑contextual — LGTM.Passing
model
intoget_unwrapped_name
aligns save/restore across wrappers.Also applies to: 166-172
examples/llm_qat/main.py (1)
259-269
: Eval path simplification with perplexity — LGTM.Straightforward and clearer. Uses rank‑0 logging properly.
tests/examples/llm_qat/test_llm_qat.py (1)
36-44
: Backend‑parameterized tests — nice coverage; watch runtime.Great to exercise fsdp1/fsdp2/deepspeed/ddp. If CI time grows, consider marking deepspeed as nightly/slow.
Also applies to: 45-58, 70-86
modelopt/torch/opt/plugins/peft.py (2)
60-66
: Centralized quantizer state save — good refactor.Avoids
state_dict()
pitfalls under FSDP and keeps logic in one place.
91-94
: Load with model‑context unwrapped names — correct.Indexes match saved dict under wrappers.
examples/llm_qat/launch.sh (4)
21-25
: Argument parser looks good.Handles both --arg=value and --arg value forms cleanly.
99-105
: Deprecation path — LGTM.Clear message; keeps backward-compat for --use_fsdp2.
110-136
: Backend switch is tidy.Centralized selection and GC args per backend read well.
138-146
: Scoped FSDP RAM-efficient toggle — LGTM.Only applied for fsdp1/fsdp2 when distill, which matches the note.
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
251-259
: Final FULL_STATE_DICT toggle — LGTM.Good safeguard before the last checkpoint write.
163-169
: Verify no HF/Accelerate prepare/load runs before restore (FSDP2/DDP/DeepSpeed)
- QATTrainer patches accelerate.prepare then immediately restores modelopt state — modelopt/torch/quantization/plugins/transformers_trainer.py:L163-L169.
- ModelOptHFTrainer.init calls enable_huggingface_checkpointing() then super().init (HF Trainer) — modelopt/torch/opt/plugins/transformers.py:L156-L159; check transformers.Trainer.init (for your transformers + accelerate versions) for any accelerator.prepare(), torch.load()/from_pretrained() or checkpoint hooks that run during init (these would run before the QAT restore).
- If unsure, reproduce with a minimal run that logs accelerator.prepare / torch.load invocation order under your target backends.
ab384f0 to c2eb5e5 Compare
1b14ede to 7d612fc Compare
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
examples/llm_qat/launch.sh (1)
63-66
: Avoid divide-by-zero and ensure save_steps ≥ 1.GPU_COUNT can be 0 (CPU runs) or large enough to make 192/GPU_COUNT evaluate to 0. Guard and clamp to at least 1.
Apply:
-# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +# Calculate save_steps (guard 0 GPUs and clamp to ≥1) +if [ "${GPU_COUNT:-0}" -gt 0 ]; then + DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) + if [ "$DEFAULT_SAVE_STEPS" -lt 1 ]; then DEFAULT_SAVE_STEPS=1; fi +else + DEFAULT_SAVE_STEPS=192 +fiexamples/llm_qat/simple_qat_train.py (2)
109-117
: Unconditional .cuda() will crash on CPU-only; move device selection earlier and use .to(device).Load, then move model to the selected device.
- # Load model and initialize loss - model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda() - tokenizer = AutoTokenizer.from_pretrained(args.model_path) - # Get dataloaders - train_dataloader, calib_dataloader = get_dataloader(args, tokenizer) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load model and tokenizer + model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + # Get dataloaders + train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
118-125
: Calibrate without gradients and with eval mode to save memory and time.Wrap calibration in inference_mode and temporarily switch to eval().
- def calibrate(m: nn.Module): - for batch in calib_dataloader: - m(batch["input_ids"].to(device)) + def calibrate(m: nn.Module): + was_training = m.training + m.eval() + with torch.inference_mode(): + for batch in calib_dataloader: + m(batch["input_ids"].to(device)) + if was_training: + m.train()modelopt/torch/quantization/plugins/transformers_trainer.py (1)
353-359
: export_student path doesn’t actually save the exported model under FSDP2 SHARDED_STATE_DICT.You export into a local var and then call QATTrainer.save_model, which saves self.model (not the exported one). Swap in the exported model temporarily or save it directly.
- if export_student: - model = self.accelerator.unwrap_model(self.model) - model = model.export() - return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs) + if export_student: + exported = self.accelerator.unwrap_model(self.model).export() + save_dir = output_dir or self.args.output_dir + exported.save_pretrained(save_dir) + return + return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
203-210
: Use of self.model in forward_loop is correct here.Per maintainer guidance for this flow, forwarding via self.model (not the unwrapped param) is intended.
🧹 Nitpick comments (8)
examples/llm_qat/launch.sh (3)
88-92
: Quote QUANT_CFG in tests to avoid word-splitting/unset issues.Use parameter expansion with default to prevent “[ -z ]” pitfalls when unset.
-if [ -z $QUANT_CFG ]; then +if [ -z "${QUANT_CFG:-}" ]; then QUANT_ARGS="" else QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE" fi
95-97
: Quote MAX_STEPS check.Minor shell robustness for unset values.
-if [ ! -z $MAX_STEPS ]; then +if [ -n "${MAX_STEPS:-}" ]; then OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" fi
52-55
: Fix invalid-arg error message.Currently prints only the value part; print the flag itself.
- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument %s\n" "$1"examples/llm_qat/simple_qat_train.py (1)
46-51
: Optional: tune DataLoader for GPU throughput.Pin memory and allow workers to reduce host-device stalls.
- train_dataloader = DataLoader( - train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn - ) - calib_dataloader = DataLoader( - calib_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn - ) + common_dl = dict(collate_fn=collate_fn, pin_memory=torch.cuda.is_available(), num_workers=2) + train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **common_dl) + calib_dataloader = DataLoader(calib_dataset, batch_size=args.batch_size, shuffle=False, **common_dl)modelopt/torch/quantization/plugins/transformers_trainer.py (4)
241-249
: Eval-only FSDP2 workaround: add safe param access.next(self.model.parameters()) can raise StopIteration for paramless wrappers; guard defensively.
- dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0) + first_param = next((p for p in self.model.parameters()), None) + if first_param is None: + return super().evaluate(*args, **kwargs) + dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)
250-261
: Robust check for FULL_STATE_DICT.State-dict type may be an enum/obj; compare via string to avoid false negatives.
- if ( - (not self.is_in_train) - and self.is_fsdp_enabled - and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT" - ): + if ( + (not self.is_in_train) + and self.is_fsdp_enabled + and "FULL_STATE_DICT" not in str(self.accelerator.state.fsdp_plugin.state_dict_type) + ):
262-294
: Optional: guard against double-patching accelerate.prepare.Avoid overwriting _original_prepare if already patched.
- self.accelerator._original_prepare = self.accelerator.prepare + if getattr(self.accelerator, "_original_prepare", None) is None: + self.accelerator._original_prepare = self.accelerator.prepare self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
321-323
: Avoid unconditional .cuda() in QAD flow.Either assert CUDA availability or move to device with fallback to improve error messaging.
- self.model.cuda() + if not torch.cuda.is_available(): + raise RuntimeError("QAD requires CUDA; no GPU detected.") + self.model.cuda()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
examples/llm_qat/README.md
(0 hunks)examples/llm_qat/accelerate_config/deepspeed.yaml
(1 hunks)examples/llm_qat/accelerate_config/fsdp1.yaml
(1 hunks)examples/llm_qat/convert_sharded_ckpt.py
(0 hunks)examples/llm_qat/launch.sh
(4 hunks)examples/llm_qat/main.py
(2 hunks)examples/llm_qat/simple_qat_train.py
(2 hunks)examples/llm_qat/utils.py
(1 hunks)modelopt/torch/opt/conversion.py
(2 hunks)modelopt/torch/opt/dynamic.py
(1 hunks)modelopt/torch/opt/plugins/peft.py
(2 hunks)modelopt/torch/quantization/calib/histogram.py
(1 hunks)modelopt/torch/quantization/conversion.py
(2 hunks)modelopt/torch/quantization/nn/modules/quant_module.py
(1 hunks)modelopt/torch/quantization/plugins/transformers_trainer.py
(5 hunks)modelopt/torch/quantization/utils.py
(2 hunks)modelopt/torch/utils/network.py
(3 hunks)tests/_test_utils/examples/run_command.py
(1 hunks)tests/examples/llm_qat/test_llm_qat.py
(3 hunks)
💤 Files with no reviewable changes (2)
- examples/llm_qat/README.md
- examples/llm_qat/convert_sharded_ckpt.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/quantization/calib/histogram.py
🚧 Files skipped from review as they are similar to previous changes (13)
- examples/llm_qat/accelerate_config/fsdp1.yaml
- examples/llm_qat/utils.py
- tests/_test_utils/examples/run_command.py
- examples/llm_qat/accelerate_config/deepspeed.yaml
- modelopt/torch/quantization/utils.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/opt/dynamic.py
- tests/examples/llm_qat/test_llm_qat.py
- modelopt/torch/quantization/conversion.py
- modelopt/torch/opt/plugins/peft.py
- examples/llm_qat/main.py
- modelopt/torch/quantization/nn/modules/quant_module.py
- modelopt/torch/utils/network.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (5)
calibrate_with_adapters
(275-286)disable_lora_quantizers_in_config
(289-296)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (3)
modelopt_state
(444-486)save
(489-507)restore_from_modelopt_state
(510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (8)
examples/llm_qat/launch.sh (3)
99-108
: Backward-compat for --use_fsdp2 and compress→ddp logic looks good.Clear precedence and explicit deprecation notice. LGTM.
110-136
: Backend switch-case is clean and explicit.Good normalization to fsdp1/fsdp, fsdp2, ddp, deepspeed, and targeted args. LGTM.
139-146
: FSDP memory-efficient loading toggle scoped to fsdp backends only.Good constraint; avoids impacting ddp/deepspeed. LGTM.
examples/llm_qat/simple_qat_train.py (1)
88-93
: CLI default-to-string with getattr is correct.Switch to string default plus getattr aligns with choices and keeps args portable. LGTM.
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
163-169
: Good: restore ModelOpt state early if present.Early restore path is aligned with the PR goal (fix FSDP2 and QLoRA interactions). LGTM.
195-201
: Calibration dataset selection fix looks correct.Select dataset before len() avoids None handling pitfalls. LGTM.
211-218
: Quantize → optional compress flow is sensible.Order and gating are appropriate. LGTM.
229-234
: Lazy quantization trigger on first step is fine.Ensures quantize runs once. LGTM.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #318      +/-   ##
==========================================
- Coverage   73.86%   73.85%    -0.02%
==========================================
  Files         172      172
  Lines       17409    17430       +21
==========================================
+ Hits        12860    12873       +13
- Misses       4549     4557        +8
☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
353-359
: export_student is ignored under FSDP2; exported model never savedLocal
model = ...export()
is unused andQATTrainer.save_model
doesn’t acceptexport_student
, so the student export is effectively dropped. Temporarily swapself.model
to the exported student for saving and restore afterward.- if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str( - self.accelerator.state.fsdp_plugin.state_dict_type - ): - if export_student: - model = self.accelerator.unwrap_model(self.model) - model = model.export() - return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs) + if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str( + self.accelerator.state.fsdp_plugin.state_dict_type + ): + original_model = self.model + try: + if export_student: + exported = self.accelerator.unwrap_model(self.model).export() + self.model = exported + return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs) + finally: + self.model = original_model
♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
197-201
: LGTM: dataset selection order fixedSelecting
dataset
before callinglen(dataset)
avoids None errors previously flagged.
278-289
: to_empty() can wipe restored quantizer buffers; snapshot/restore around prepareCalling
to_empty()
before FSDP2 prepare can drop quantizer buffer storage (problematic for eval-only flows after restoring state). Snapshot quantizer state before the loop and restore it afterprepare
.tq_og_non_prsist_buffers = {} + saved_tq_state = get_quantizer_state_dict(model) for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): tq.to_empty(device=self.device) tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy() tq._non_persistent_buffers_set.update(tq._buffers.keys()) @@ outputs = self._original_prepare(*args, **kwargs) for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): tq._non_persistent_buffers_set.clear() tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq] + set_quantizer_state_dict(model, saved_tq_state)Optional nit: fix the variable name typo for readability.
- tq_og_non_prsist_buffers = {} + tq_og_non_persistent_buffers = {} @@ - tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy() + tq_og_non_persistent_buffers[tq] = tq._non_persistent_buffers_set.copy() @@ - tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq] + tq._non_persistent_buffers_set = tq_og_non_persistent_buffers[tq]
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
253-260
: StateDictType comparison is brittle; string vs enum mismatch
state_dict_type != "FULL_STATE_DICT"
will always be True if it’s an enum. Use string containment (as done elsewhere) or compare to the enum.- and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT" + and "FULL_STATE_DICT" not in str(self.accelerator.state.fsdp_plugin.state_dict_type)
292-293
: Avoid double‑patching Accelerator.prepareGuard
_original_prepare
assignment to prevent recursion if patched more than once.- self.accelerator._original_prepare = self.accelerator.prepare - self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator) + if not hasattr(self.accelerator, "_original_prepare"): + self.accelerator._original_prepare = self.accelerator.prepare + self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
203-210
: Forward loop param unused; mute tqdm on non‑master and clarify commentRename the unused param and avoid multi‑process progress spam. Keep using
self.model(**batch)
as intended.- def forward_loop(model): - for batch in tqdm(data_loader, desc="Calibrating"): + def forward_loop(_): + for batch in tqdm( + data_loader, desc="Calibrating", disable=not self.accelerator.is_local_main_process + ): batch = self._prepare_inputs(batch) - # Important: We should forward pass using the unwrapped model - # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop + # Intentionally call self.model; quantize() manages unwrapping internally. self.model(**batch)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/plugins/transformers_trainer.py
(5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (5)
calibrate_with_adapters
(275-286)disable_lora_quantizers_in_config
(289-296)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (3)
modelopt_state
(444-486)save
(489-507)restore_from_modelopt_state
(510-567)modelopt/torch/quantization/model_quant.py (2)
forward_loop
(95-96)quantize
(132-227)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
206-208
: LGTM: forward pass uses self.model per intended designThis aligns with the agreed behavior for the calibration loop.
158-162
: Confirm QuantizeConfig mapping semantics in AWQ/SmoothQuant check
quant_cfg.get("algorithm", {})
assumes dict‑like behavior. Ifquant_cfg
can be aQuantizeConfig
object, verify it supports.get()
; otherwise normalize to a dict first.
165-169
: LGTM: restore ModelOpt state earlyRestoring before any weight load matches the FSDP2/QLoRA requirement and unblocks the workflow.
Does the conversion from sharded_dict to full_dict for saving work for slightly larger models (maybe 8B/14B), especially with QAD?
self.model(**batch)
# TODO: Remove calibrate_with_adpaters - this should not be needed
with calibrate_with_adapters(self.model, self.args):
@sugunav14 why do we need to use the `calibrate_with_adapters` context manager? I doubt it is truly needed.
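For reference, a minimal sketch of the pattern such a context manager typically follows during calibration. This is illustrative only, not the actual modelopt implementation; the args.lora flag and the PEFT-style disable/enable_adapter_layers calls are assumptions:
import contextlib

@contextlib.contextmanager
def calibrate_with_adapters_sketch(model, args):
    # Illustrative only: temporarily turn LoRA adapter layers off so calibration
    # statistics are collected on the base weights, then re-enable them afterwards.
    lora_enabled = getattr(args, "lora", False)  # hypothetical flag
    has_peft_api = hasattr(model, "disable_adapter_layers")
    if lora_enabled and has_peft_api:
        model.disable_adapter_layers()
    try:
        yield model
    finally:
        if lora_enabled and has_peft_api:
            model.enable_adapter_layers()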
LGTM, thanks for the fix!
@kinjalpatel27 I have tested an 8B model with QAT. Let me double-check with QAD (I am assuming it should work, since QAD exports the student model before saving). Could you please take a look at the other items?
@realAsma generally LGTM, two comments:
Will you be updating the README in a later PR?
Approving as codeowner to unblock
Fixed. There were some additional bugs in the script; I fixed those too.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_qat/launch.sh (1)
62-66
: Fix division-by-zero and clamp DEFAULT_SAVE_STEPS to ≥1.On CPU-only runs (
GPU_COUNT=0
) the arithmetic fails. With large GPU counts, result may be 0 (invalid).Apply:
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +# Calculate save_steps safely +if ! [[ "$GPU_COUNT" =~ ^[0-9]+$ ]] || [ "$GPU_COUNT" -lt 1 ]; then + GPU_COUNT=1 +fi +DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +if [ "$DEFAULT_SAVE_STEPS" -lt 1 ]; then + DEFAULT_SAVE_STEPS=1 +fi
🧹 Nitpick comments (10)
examples/llm_qat/launch.sh (6)
29-52
: Tighten option patterns or guard for collisions.Patterns like
--model*
can match unintended flags (e.g.,--model_name...
). Consider exact matches to avoid accidental captures.Example:
- --model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --model|--model=*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;Repeat for other options as needed.
88-98
: Quote variables in tests to avoid word-splitting/globbing.Unquoted vars in
[ -z $QUANT_CFG ]
and similar can misbehave.Apply:
-if [ -z $QUANT_CFG ]; then +if [ -z "${QUANT_CFG:-}" ]; then QUANT_ARGS="" else QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE" fi OPTIONAL_ARGS="" -if [ ! -z $MAX_STEPS ]; then +if [ -n "${MAX_STEPS:-}" ]; then OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" fi
99-107
: Override to ddp on --compress: add an explicit warning.You silently override a user-selected backend. Surface a warning to avoid surprise.
Apply:
-if [[ "${COMPRESS,,}" == "true" ]]; then - BACKEND="ddp" +if [[ "${COMPRESS,,}" == "true" ]]; then + echo "Info: --compress enabled; forcing --backend=ddp (FSDP not supported with compression)." >&2 + BACKEND="ddp" fi
148-179
: Quote CLI argument values in CMD for safety.Paths or model ids with spaces/shell metacharacters can break the command.
Apply:
-CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ +CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ main.py \ - --model_name_or_path $MODEL \ - --model_max_length $MAX_SEQ_LENGTH \ + --model_name_or_path \"$MODEL\" \ + --model_max_length \"$MAX_SEQ_LENGTH\" \ --dataloader_drop_last True \ - --do_train $DO_TRAIN \ + --do_train \"$DO_TRAIN\" \ --do_eval True \ - --output_dir $OUTPUT_DIR \ - --dataset $DATASET \ - --train_size $TRAIN_SIZE \ - --eval_size $EVAL_SIZE \ - --num_train_epochs $NUM_EPOCHS \ - --per_device_train_batch_size $TRAIN_BS \ - --per_device_eval_batch_size $EVAL_BS \ - --gradient_accumulation_steps $ACCUM_STEPS \ + --output_dir \"$OUTPUT_DIR\" \ + --dataset \"$DATASET\" \ + --train_size \"$TRAIN_SIZE\" \ + --eval_size \"$EVAL_SIZE\" \ + --num_train_epochs \"$NUM_EPOCHS\" \ + --per_device_train_batch_size \"$TRAIN_BS\" \ + --per_device_eval_batch_size \"$EVAL_BS\" \ + --gradient_accumulation_steps \"$ACCUM_STEPS\" \ --eval_accumulation_steps 1 \ --save_strategy steps \ - --save_steps $SAVE_STEPS \ + --save_steps \"$SAVE_STEPS\" \ --eval_strategy steps \ - --eval_steps $SAVE_STEPS \ + --eval_steps \"$SAVE_STEPS\" \ --load_best_model_at_end True \ --save_total_limit 2 \ - --learning_rate $LR \ + --learning_rate \"$LR\" \ --weight_decay 0.0 \ --warmup_ratio 0.1 \ --lr_scheduler_type linear \ --logging_steps 1 \ --report_to tensorboard \ - --lora $LORA \ - --compress $COMPRESS \ + --lora \"$LORA\" \ + --compress \"$COMPRESS\" \ $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS "
53-55
: Minor: error message should print the offending token verbatim.
${1#*=}
can mangle strings; use$1
.Apply:
- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument %s\n" "$1"
60-60
: Trace mode: keep or gate by VERBOSE.
set -x
is useful for CI but noisy for users. Consider gating onVERBOSE=1
.Example:
- set -x + [[ "${VERBOSE:-0}" == "1" ]] && set -xmodelopt/torch/quantization/plugins/transformers_trainer.py (4)
199-201
: Add error handling for missing state file.The method calls
restore_modelopt_state_with_weights
without checking if the file exists, which could raise aFileNotFoundError
. This is inconsistent with the check performed in__init__
at lines 177-178.Apply this diff to add error handling:
def _restore_modelopt_state_with_weights(self): + if not os.path.exists(self._modelopt_state_path): + print_rank_0(f"ModelOpt state file not found: {self._modelopt_state_path}") + return restore_modelopt_state_with_weights(self.model, self._modelopt_state_path) print_rank_0("Restored modelopt state with weights.")
218-219
: Consider removing the TODO forcalibrate_with_adapters
.The TODO comment questions whether
calibrate_with_adapters
is needed, but since the PR objectives mention QLoRA support and testing, this context manager is likely required to properly disable LoRA adapters during calibration. Consider either removing the TODO or clarifying why it might not be needed in the future.
340-345
: Add memory checks before moving model to GPU.The
self.model.cuda()
call at line 339 could fail for large models that exceed GPU memory. Consider adding a try-catch or memory availability check, especially since the comment mentions that "memory efficient loading doesn't work" for QAD.Apply this diff to add error handling:
- self.model.cuda() + try: + self.model.cuda() + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Failed to move model to GPU. QAD requires the entire model to fit in GPU memory. " + "Consider using QAT instead or reducing model/batch size." + ) from e
406-410
: Add error handling for checkpoint loading failures.The
dist_cp.load_state_dict
call could fail if the checkpoint is corrupted or incompatible. Consider adding error handling to provide a clearer error message.Apply this diff to add error handling:
- dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=dist_cp.FileSystemReader(sharded_model_path), - no_dist=True, - ) + try: + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(sharded_model_path), + no_dist=True, + ) + except Exception as e: + print_rank_0(f"Failed to load sharded checkpoint from {sharded_model_path}: {e}") + raise
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/llm_qat/convert_sharded_ckpt.py
(0 hunks)examples/llm_qat/launch.sh
(4 hunks)examples/llm_qat/utils.py
(1 hunks)modelopt/torch/quantization/plugins/transformers_trainer.py
(7 hunks)
💤 Files with no reviewable changes (1)
- examples/llm_qat/convert_sharded_ckpt.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_qat/utils.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
TensorQuantizer
(62-1182)collect
(1166-1176)modelopt/torch/quantization/utils.py (5)
calibrate_with_adapters
(275-286)disable_lora_quantizers_in_config
(289-296)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)modelopt/torch/opt/conversion.py (4)
modelopt_state
(444-486)save
(489-507)state_dict
(130-132)load_state_dict
(139-161)modelopt/torch/opt/plugins/huggingface.py (1)
enable_huggingface_checkpointing
(127-162)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (7)
examples/llm_qat/launch.sh (2)
110-137
: Backend case: LGTM with minor polish.Case handling is clear, includes fsdp↔fsdp1 alias and deprecation note for fsdp2.
Please confirm accelerate_config/{fsdp1.yaml,fsdp2.yaml,ddp.yaml,deepspeed.yaml} exist in this PR to prevent runtime failures.
138-146
: FSDP memory-efficient loading disable under distill: scoped correctly.Scoping to fsdp1/fsdp2 only is appropriate.
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
182-197
: LGTM! Clean separation of state and weights for FSDP2 compatibility.The save method now correctly stores both the filtered ModelOpt state dict and quantizer weights separately, which fixes the FSDP2 workflow issues mentioned in the PR objectives.
203-231
: LGTM! Comprehensive quantization flow with dataset validation.The quantization method properly validates dataset availability, creates a calibration subset, and saves the state after quantization. The integration with
calibrate_with_adapters
ensures LoRA adapters are correctly handled during calibration.
384-416
: LGTM! Well-structured checkpoint conversion utility.The new
convert_sharded_model_to_hf_format
function properly handles the conversion from FSDP sharded checkpoints to HuggingFace format, addressing the checkpoint compatibility issues mentioned in the PR objectives.
250-255
: Verify FSDP2 eval-only workaround doesn't affect training mode.Search found the dummy-optimizer hack only at modelopt/torch/quantization/plugins/transformers_trainer.py:250-255; ensure overwriting self.model via accelerator.prepare(dummy_optimizer) in the eval-only branch cannot leave the model in an FSDP-wrapped/prepared state that would interfere with later training — use a locally prepared model for evaluation or re-prepare the model with the real training optimizer before any subsequent training.
173-180
: Ensure _patch_accelerate_for_fsdp2_fix() runs before any Accelerate prepare calls. The method monkey-patches accelerator.prepare to hide quantizer buffers, so it must execute before any call to .prepare() (including calls from parent-class __init__ or plugins). Verify by inspecting modelopt/torch/quantization/plugins/transformers_trainer.py (around lines 173-180) and confirming the patch is applied before any super().__init__()/plugin initialization that could invoke accelerator.prepare; if a prepare() can run earlier, move the patch to the very start of __init__ (or earlier) so it always precedes any .prepare() invocation. Automated sandbox checks failed; please run a repo search for ".prepare(" and confirm the call order. A small ordering check is sketched below.
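As a rough way to check that ordering during debugging, one could assert that nothing has been prepared yet at the moment the patch is applied. This is a hypothetical helper, and the _models/_optimizers attributes on accelerate's Accelerator are assumptions that may vary across accelerate versions:
def assert_prepare_not_yet_called(accelerator):
    # Hypothetical debugging helper: if accelerate has already prepared a model or
    # optimizer, patching accelerator.prepare now is too late for those objects.
    already_prepared = list(getattr(accelerator, "_models", [])) + list(
        getattr(accelerator, "_optimizers", [])
    )
    assert not already_prepared, (
        "accelerator.prepare() already ran before the ModelOpt patch was applied; "
        "apply _patch_accelerate_for_fsdp2_fix() earlier in __init__."
    )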
069b244 to 7551732 Compare
7551732 to b605d8c Compare
print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
outputs = super().save_model(*args, **kwargs)
torch.distributed.barrier()
@kinjalpatel27 Converting FSDP2 to use FULL_STATE_DICT works with both QAT and QAD on larger models after inserting this distributed barrier (the problem was that, while rank 0 was still saving the model, the other ranks had already exited the program).
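For readers following along, a minimal sketch of that save pattern, based on the snippet quoted above and assuming torch.distributed is initialized (the function name and trainer argument are illustrative, not the exact trainer code):
import torch.distributed as dist

def save_final_full_checkpoint(trainer, *args, **kwargs):
    # Switch FSDP2 to a full (unsharded) state dict for the last save.
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    outputs = trainer.save_model(*args, **kwargs)  # rank 0 performs the actual write
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # keep non-zero ranks alive until the save completes
    return outputs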
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_qat/launch.sh (1)
63-66
: Guard division by zero when no GPUs are presentGPU_COUNT can be 0, causing DEFAULT_SAVE_STEPS=$((192 / 0)) to fail.
Apply:
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo 0) +# Calculate save_steps (fallback to 192 on CPU/no CUDA) +if (( GPU_COUNT > 0 )); then + DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) +else + >&2 echo "Warning: No GPUs detected; using DEFAULT_SAVE_STEPS=192" + DEFAULT_SAVE_STEPS=192 +fi
♻️ Duplicate comments (1)
examples/llm_qat/launch.sh (1)
21-25
: Fix arg parsing: avoid global shift + validate missing valuesCurrent parse_value shifts the global positional params (double-shifts with the caller) and accepts a following flag as a value. Replace with a non-shifting, validating version.
Apply:
-# Helper function to parse a single argument value -parse_value() { - if [[ "$1" != *=* ]]; then shift; fi - echo "${1#*=}" -} +# Helper: extract value from "--opt=value" or "--opt value"; fail if missing (no global shift) +parse_value() { + local first="$1" + local next="${2-}" + if [[ "$first" == *=* ]]; then + echo "${first#*=}" + return 0 + fi + if [[ -z "$next" || "$next" == --* || "$next" == -* ]]; then + >&2 echo "Error: Missing value for option '$first'" + exit 2 + fi + echo "$next" +}
🧹 Nitpick comments (18)
examples/llm_qat/launch.sh (6)
29-57
: Correct invalid-arg error message to show the flag, not the valuePrinting ${1#*=} hides the flag when using --opt=value.
Apply:
- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument %s\n" "$1"
88-97
: Quote tests; prefer [[ -n ]] to avoid word-splitting and edge casesUnquoted vars may misbehave when empty or starting with '-'.
Apply:
-if [ -z $QUANT_CFG ]; then +if [[ -z "$QUANT_CFG" ]]; then QUANT_ARGS="" else QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE" fi OPTIONAL_ARGS="" -if [ ! -z $MAX_STEPS ]; then +if [[ -n "$MAX_STEPS" ]]; then OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" fi
99-108
: Warn when COMPRESS overrides user-selected BACKENDCurrently COMPRESS silently forces ddp. Emit an explicit notice.
Apply:
# if compress is true, set backend to ddp if [[ "${COMPRESS,,}" == "true" ]]; then - BACKEND="ddp" + if [[ "${BACKEND,,}" != "ddp" ]]; then + echo "Info: --compress enabled; overriding backend '$BACKEND' -> 'ddp'" + fi + BACKEND="ddp" fi
148-179
: Quote high-risk args in CMD to handle spaces safelyModel path, output dir, and dataset may contain spaces; quote them.
Apply:
- --model_name_or_path $MODEL \ + --model_name_or_path "$MODEL" \ @@ - --output_dir $OUTPUT_DIR \ + --output_dir "$OUTPUT_DIR" \ - --dataset $DATASET \ + --dataset "$DATASET" \Optional follow-up: build CMD as an array to avoid word-splitting entirely.
181-183
: Ensure timing prints even on failureUse a trap so elapsed time is reported when the run exits early.
Apply:
-start_time=$(date +%s) -sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" +start_time=$(date +%s) +trap 'echo "Total time taken: $(( $(date +%s) - start_time )) seconds"' EXIT +sh -c "$CMD"
60-60
: Gate shell tracing behind DEBUGset -x can leak sensitive info; make it opt-in.
Apply:
-set -x +[[ "${DEBUG:-0}" == "1" ]] && set -xmodelopt/torch/quantization/conversion.py (2)
115-123
: Avoid recomputing quantizer_state(model) and tighten mismatch reporting.Compute it once to reduce traversal overhead and provide clearer errors.
- quantizer_state_dict = metadata["quantizer_state"] - unmatched_keys = quantizer_state_dict.keys() - quantizer_state(model).keys() - extra_keys = quantizer_state(model).keys() - quantizer_state_dict.keys() + quantizer_state_dict = metadata["quantizer_state"] + current_state = quantizer_state(model) + unmatched_keys = quantizer_state_dict.keys() - current_state.keys() + extra_keys = current_state.keys() - quantizer_state_dict.keys()Also applies to: 124-133
166-172
: Guard against potential unwrapped-name collisions.If two wrapped module paths normalize to the same unwrapped key, one state can overwrite the other silently. Consider asserting uniqueness (debug-only) or logging a warning.
def quantizer_state(model: nn.Module) -> dict[str, Any]: """Returns the quantizer state dict describing the quantizer states in the model.""" - return { + state = { get_unwrapped_name(n, model): m.get_modelopt_state() for n, m in model.named_modules() if isinstance(m, (TensorQuantizer, SequentialQuantizer)) } + # Optional: debug safeguard + # assert len(state) == len({get_unwrapped_name(n, model) for n, m in model.named_modules() + # if isinstance(m, (TensorQuantizer, SequentialQuantizer))}), \ + # "Duplicate unwrapped quantizer names detected." + return statemodelopt/torch/quantization/plugins/transformers_trainer.py (10)
163-170
: Early restore/save path looks good; ensure output_dir exists.Create the directory before save/restore to avoid surprises on fresh runs.
- self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth") + self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth") + os.makedirs(self.args.output_dir, exist_ok=True)
190-196
: Guard when modelopt_state_weights is absent.Older checkpoints may lack weights; avoid calling set_quantizer_state_dict with None.
- modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) restore_from_modelopt_state(self.model, modelopt_state) - set_quantizer_state_dict(self.model, modelopt_weights) + if modelopt_weights is not None: + set_quantizer_state_dict(self.model, modelopt_weights)
197-216
: Calibration loop: limit tqdm to main process.Prevents N progress bars under DDP/FSDP.
- def forward_loop(model): - for batch in tqdm(data_loader, desc="Calibrating"): + def forward_loop(model): + for batch in tqdm(data_loader, desc="Calibrating", + disable=not self.accelerator.is_main_process): batch = self._prepare_inputs(batch) # Important: We should forward pass using the unwrapped model # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop self.model(**batch)
212-213
: Typo in TODO.“adpaters” → “adapters”.
217-226
: Emptying CUDA cache: gate on CUDA availability.Avoids unnecessary call on CPU-only nodes.
- torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache()
242-249
: Eval-only FSDP2 hack: OK; add small safeguard.If model has no parameters (edge adapters), next(self.model.parameters()) raises StopIteration.
- dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0) + first_param = next(self.model.parameters(), None) + if first_param is None: + return super().evaluate(*args, **kwargs) + dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)
259-274
: Switching to FULL_STATE_DICT: consider restoring previous setting after save.Prevents persistent side-effects if more saves follow.
- self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - outputs = super().save_model(*args, **kwargs) + prev = self.accelerator.state.fsdp_plugin.state_dict_type + self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + try: + outputs = super().save_model(*args, **kwargs) + finally: + self.accelerator.state.fsdp_plugin.set_state_dict_type(prev)
275-307
: Patch safety: avoid double-wrapping and restore non-persistent buffer set in-place.
- Don’t overwrite _original_prepare if already patched.
- Restore _non_persistent_buffers_set contents without reassigning the set (prevents reference invalidation).
- def _modelopt_prepare(self, *args, **kwargs): + def _modelopt_prepare(self, *args, **kwargs): if not self.is_fsdp2: return self._original_prepare(*args, **kwargs) - model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None) + model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None) + if model is None: + model = next((obj for obj in kwargs.values() if isinstance(obj, torch.nn.Module)), None) if model is None: return self._original_prepare(*args, **kwargs) tq_og_non_prsist_buffers = {} for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): - tq.to_empty(device=self.device) tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy() tq._non_persistent_buffers_set.update(tq._buffers.keys()) outputs = self._original_prepare(*args, **kwargs) for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): - tq._non_persistent_buffers_set.clear() - tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq] + tq._non_persistent_buffers_set.clear() + tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq]) return outputs - self.accelerator._original_prepare = self.accelerator.prepare + if getattr(self.accelerator, "_original_prepare", None) is None: + self.accelerator._original_prepare = self.accelerator.prepare self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
291-296: Re-evaluate calling to_empty() on TensorQuantizer during prepare. It can drop/restage buffers; you’re already hiding them via _non_persistent_buffers_set. If you keep it, gate it to meta-device tensors only.
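If the call stays, a minimal sketch of such a gate (illustrative only; the helper name and its placement inside the PR's _modelopt_prepare wrapper are assumptions, not the actual implementation):

import torch

def _materialize_if_meta(module: torch.nn.Module, device: torch.device) -> None:
    # Only call to_empty() when the module still holds meta tensors;
    # already-materialized parameters/buffers are left untouched.
    tensors = list(module.parameters(recurse=False)) + list(module.buffers(recurse=False))
    if any(t.is_meta for t in tensors):
        module.to_empty(device=device, recurse=False)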
366-375: State-dict type check by string match is brittle. Prefer comparing the enum/constant directly if available.
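For reference, a hedged sketch of an enum-based check, assuming the Accelerate FSDP plugin stores a torch.distributed.fsdp.StateDictType (fall back to a name comparison if it holds a plain string):

from torch.distributed.fsdp import StateDictType

def is_full_state_dict(fsdp_plugin) -> bool:
    sd_type = fsdp_plugin.state_dict_type
    if isinstance(sd_type, StateDictType):
        return sd_type == StateDictType.FULL_STATE_DICT
    # Fallback for plugins that keep the setting as a string
    return str(sd_type).upper().endswith("FULL_STATE_DICT")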
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- examples/llm_qat/README.md (0 hunks)
- examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
- examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
- examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
- examples/llm_qat/launch.sh (4 hunks)
- examples/llm_qat/llama_factory/launch_llamafactory.sh (0 hunks)
- examples/llm_qat/main.py (2 hunks)
- examples/llm_qat/simple_qat_train.py (3 hunks)
- examples/llm_qat/utils.py (1 hunks)
- modelopt/torch/opt/conversion.py (2 hunks)
- modelopt/torch/opt/dynamic.py (1 hunks)
- modelopt/torch/opt/plugins/peft.py (2 hunks)
- modelopt/torch/quantization/calib/histogram.py (1 hunks)
- modelopt/torch/quantization/conversion.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
- modelopt/torch/quantization/utils.py (2 hunks)
- modelopt/torch/utils/network.py (3 hunks)
- tests/_test_utils/examples/run_command.py (1 hunks)
- tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (3)
- examples/llm_qat/README.md
- examples/llm_qat/convert_sharded_ckpt.py
- examples/llm_qat/llama_factory/launch_llamafactory.sh
🚧 Files skipped from review as they are similar to previous changes (14)
- examples/llm_qat/accelerate_config/fsdp1.yaml
- tests/examples/llm_qat/test_llm_qat.py
- modelopt/torch/opt/dynamic.py
- modelopt/torch/quantization/nn/modules/quant_module.py
- examples/llm_qat/utils.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/quantization/calib/histogram.py
- examples/llm_qat/main.py
- examples/llm_qat/simple_qat_train.py
- modelopt/torch/opt/plugins/peft.py
- tests/_test_utils/examples/run_command.py
- modelopt/torch/quantization/utils.py
- examples/llm_qat/accelerate_config/deepspeed.yaml
- modelopt/torch/utils/network.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/quantization/conversion.py
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/conversion.py (3)
- modelopt/torch/utils/network.py (1): get_unwrapped_name (599-612)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3): set_from_modelopt_state (1122-1140), get_modelopt_state (1105-1120), get_modelopt_state (1246-1248)
- modelopt/torch/quantization/nn/modules/quant_module.py (1): QuantModule (37-114)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2): TensorQuantizer (62-1182), collect (1166-1176)
- modelopt/torch/quantization/utils.py (5): calibrate_with_adapters (275-286), disable_lora_quantizers_in_config (289-296), get_quantizer_state_dict (446-456), is_quantized (239-243), set_quantizer_state_dict (459-466)
- modelopt/torch/opt/conversion.py (3): modelopt_state (444-486), save (489-507), restore_from_modelopt_state (510-567)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
171-189: Barrier + state filtering LGTM; minor: is the saved blob's shape documented? Consider documenting the saved schema (keys) for future compatibility/migrations.
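For example, the save helper could carry a short docstring pinning down the blob layout. A sketch only: the key names mirror the ones referenced in this PR, everything else (function name, arguments) is illustrative:

import torch

def save_modelopt_train_state(modelopt_state_dict, quantizer_weights, path):
    """Persist the training-time ModelOpt blob.

    Assumed schema:
      modelopt_state_dict    -- filtered ModelOpt modes (KD/export entries removed)
      modelopt_state_weights -- per-TensorQuantizer state dicts (amax, scales, ...)
    """
    blob = {
        "modelopt_state_dict": modelopt_state_dict,
        "modelopt_state_weights": quantizer_weights,
    }
    torch.save(blob, path)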
230-241: On-demand quantization trigger LGTM. The gating is correct and minimizes redundant work.
335-335: Quantize-before-adding LoRA in QAD differs from QAT. Confirm intent: QAT adds LoRA before quantization, QAD adds it after. Verify LoRA quantizers are correctly disabled for QAD.
Actionable comments posted: 3
♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
197-204: LGTM: dataset selection bug fixed. Now selects the available dataset before computing length; avoids a NoneType len() error.
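The selection order being endorsed looks roughly like this (a sketch with hypothetical argument names, not the trainer's actual code):

from torch.utils.data import Subset

def select_calibration_dataset(train_dataset, eval_dataset, calib_size):
    # Pick whichever dataset is actually available before calling len() on it.
    dataset = train_dataset if train_dataset is not None else eval_dataset
    if dataset is None:
        raise ValueError("Calibration needs either a train or an eval dataset")
    return Subset(dataset, range(min(calib_size, len(dataset))))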
275-306: Avoid recursive accelerate.prepare patching (potential infinite recursion). On repeated calls, _original_prepare gets overwritten with the already-patched prepare, making the wrapper call itself. Guard the assignment and add a one-time patch flag.
 def _patch_accelerate_for_fsdp2_fix(self):
@@
-    def _modelopt_prepare(self, *args, **kwargs):
+    def _modelopt_prepare(self, *args, **kwargs):
         if not self.is_fsdp2:
             return self._original_prepare(*args, **kwargs)
@@
         return outputs
@@
-    self.accelerator._original_prepare = self.accelerator.prepare
-    self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+    # Patch only once and keep a stable pointer to the original prepare
+    if not getattr(self.accelerator, "_modelopt_prepare_patched", False):
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
+        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        self.accelerator._modelopt_prepare_patched = True
🧹 Nitpick comments (4)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
291-296: Nit: fix the variable name and keep the set object stable. Typo in the variable name, and you’re already restoring contents in-place (good). Rename for clarity.
-        tq_og_non_prsist_buffers = {}
+        tq_orig_non_persist_buffers = {}
@@
-            tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
+            tq_orig_non_persist_buffers[tq] = tq._non_persistent_buffers_set.copy()
@@
-            tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
+            tq._non_persistent_buffers_set.update(tq_orig_non_persist_buffers[tq])
Also applies to: 299-302
205-211: Calibrate under no_grad/inference_mode to cut memory/overhead. Forward passes for calibration don’t need grads.
-    def forward_loop(model):
-        for batch in tqdm(data_loader, desc="Calibrating"):
-            batch = self._prepare_inputs(batch)
-            # Important: We should forward pass using the unwrapped model
-            # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
-            self.model(**batch)
+    def forward_loop(model):
+        with torch.inference_mode():
+            for batch in tqdm(data_loader, desc="Calibrating"):
+                batch = self._prepare_inputs(batch)
+                # Intentionally use self.model per HF/ModelOpt integration contract
+                self.model(**batch)
212-216: Is calibrate_with_adapters needed? Gate or remove if redundant. If LoRA disabling isn’t required, consider removing it or gating it via a flag to avoid surprising behavior.
242-249: Eval-only FSDP2 prepare: guard models without parameters. next(self.model.parameters()) will raise on parameter-less models. Use a safe fallback.
-        dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+        first_param = next(iter(self.model.parameters()), None)
+        if first_param is None:
+            first_param = torch.nn.Parameter(torch.zeros(1, device=self.model.device))
+        dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2): TensorQuantizer (62-1182), collect (1166-1176)
- modelopt/torch/quantization/utils.py (5): calibrate_with_adapters (275-286), disable_lora_quantizers_in_config (289-296), get_quantizer_state_dict (446-456), is_quantized (239-243), set_quantizer_state_dict (459-466)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
217-226: LGTM: post-calibration housekeeping. Compression optionality, GC, CUDA cache clear, and persisting state look good.
230-235: LGTM: on-demand quantization before the first train step. Simple, and avoids repeated work once is_quantized() is true.
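The approved pattern is roughly the following (a sketch assuming a training_step override and the is_quantized helper from modelopt.torch.quantization.utils; the real hook point and quantization path in the trainer may differ):

from transformers import Trainer
from modelopt.torch.quantization.utils import is_quantized

class LazyQuantizeTrainer(Trainer):
    """Sketch only: quantize lazily, right before the first optimization step."""

    def _quantize_model(self):
        raise NotImplementedError  # stands in for the calibration + quantization path

    def training_step(self, *args, **kwargs):
        if not is_quantized(self.model):
            self._quantize_model()
        return super().training_step(*args, **kwargs)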
163-170: Startup restore/save flow: good ordering. Restoring existing ModelOpt state before any weight loading, and persisting post-quantization, aligns with the FSDP2/QLoRA fix.
Please confirm that all entry points which may construct a Trainer instance do so before any external weight loading occurs for FSDP2 paths.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
302-302: Note: the prior discussion on calling to_empty() stands. You previously confirmed this is intended; leaving as-is.
🧹 Nitpick comments (9)
modelopt/torch/quantization/plugins/transformers_trainer.py (9)
49-55: Fix typos in docstring (clarity/polish). “This classes” → “This class”; “taining” → “training”.
-    """Quantization arguments for quantization aware training.
+    """Quantization arguments for quantization-aware training.
@@
-    This classes is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models.
+    This class is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models.
@@
-    from the command line to the taining script.
+    from the command line to the training script.
92-101: Make AWQ/SmoothQuant detection robust for dict or QuantizeConfig. check_awq_smoothquant assumes a dict; support QuantizeConfig or objects with .algorithm.
-def check_awq_smoothquant(quant_cfg):
+def check_awq_smoothquant(quant_cfg):
@@
-    algorithm = quant_cfg.get("algorithm", {})
+    if hasattr(quant_cfg, "algorithm"):
+        algorithm = getattr(quant_cfg, "algorithm") or {}
+    elif isinstance(quant_cfg, dict):
+        algorithm = quant_cfg.get("algorithm", {}) or {}
+    else:
+        algorithm = {}
Also applies to: 103-115
163-170: Early FSDP2 patch + state restore/save ordering looks right. Restoring ModelOpt states before weights and saving when already quantized is correct for FSDP2/QLoRA flows. Consider ensuring output_dir exists.
-        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
+        os.makedirs(self.args.output_dir, exist_ok=True)
+        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
171-189: Save path: include the barrier only when initialized; filtering KD/export state is fine. Minor nit: optional guard for availability; otherwise LGTM.
-        if torch.distributed.is_initialized():
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
             torch.distributed.barrier()
222-227: Guard CUDA cache emptying. Avoid calling empty_cache on CPU-only builds.
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
243-251: Handle models with zero trainable parameters in eval-only FSDP2. next(self.model.parameters()) can raise StopIteration for fully-frozen/adapter-only models.
-        dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+        first_param = next(self.model.parameters(), None)
+        if first_param is None:
+            dummy = torch.nn.Parameter(torch.zeros(1, device=self.accelerator.device))
+            dummy_optimizer = torch.optim.SGD([dummy], lr=0.0)
+        else:
+            dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)
261-283: Save path: compare fsdp state_dict_type via str() for robustness. Some plugins use enums/objects; a string compare avoids false negatives.
-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+            and str(self.accelerator.state.fsdp_plugin.state_dict_type) != "FULL_STATE_DICT"
284-316: Make the Accelerate prepare patch idempotent to avoid recursion on re-patches. Guard the _original_prepare assignment so multiple trainers don’t capture the wrapped function and recurse.
-        self.accelerator._original_prepare = self.accelerator.prepare
-        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
+        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
300-311: Nit: fix the variable name and keep the in-place restore of non-persistent buffer sets. Spelling and clarity; the logic is good.
-        tq_og_non_prsist_buffers = {}
+        tq_orig_non_persistent_buffers = {}
@@
-            tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
+            tq_orig_non_persistent_buffers[tq] = tq._non_persistent_buffers_set.copy()
@@
-            tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
+            tq._non_persistent_buffers_set.update(tq_orig_non_persistent_buffers[tq])
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T20:14:34.725Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.725Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2): TensorQuantizer (62-1182), collect (1166-1176)
- modelopt/torch/quantization/utils.py (4): calibrate_with_adapters (275-286), get_quantizer_state_dict (446-456), is_quantized (239-243), set_quantizer_state_dict (459-466)
- modelopt/torch/opt/conversion.py (5): modelopt_state (444-486), save (489-507), restore_from_modelopt_state (510-567), ModeloptStateManager (63-311), is_converted (102-127)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
🔇 Additional comments (5)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
145-157: LoRA adapter handling before quantization looks good.
190-197: Restore logic is correct and backward-safe for missing weights.
198-205: Dataset selection ordering fix is correct. Using a Subset of the first calib_size examples is fine; using the eval dataloader for a train dataset is acceptable for calibration, but verify collator differences don’t affect stats.
213-217: Revisit the necessity of calibrate_with_adapters (and ensure it’s a contextmanager). If there is no functional benefit, consider removing it to simplify; also confirm it’s decorated with @contextmanager.
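If the helper stays, its expected shape is roughly the following. This is a sketch, not the actual implementation in modelopt/torch/quantization/utils.py; the lora flag on the args object and the PEFT enable/disable calls are assumptions:

from contextlib import contextmanager

@contextmanager
def calibrate_with_adapters(model, args):
    """Temporarily disable LoRA adapter layers while calibration statistics are collected."""
    use_lora = getattr(args, "lora", False)
    if use_lora and hasattr(model, "disable_adapter_layers"):
        model.disable_adapter_layers()
    try:
        yield
    finally:
        if use_lora and hasattr(model, "enable_adapter_layers"):
            model.enable_adapter_layers()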
229-230: Quant summary only on the main process — good practice.
- …; Cleaned QATTrainer (Signed-off-by: realAsma <[email protected]>)
- bug fix (Signed-off-by: realAsma <[email protected]>)
- Fixed full_state_dict hang by inserting barrier (Signed-off-by: realAsma <[email protected]>)
- minor (Signed-off-by: realAsma <[email protected]>)
- minor (Signed-off-by: realAsma <[email protected]>)
- minor (Signed-off-by: realAsma <[email protected]>)
- minor (Signed-off-by: realAsma <[email protected]>)
- minor (Signed-off-by: realAsma <[email protected]>)
a9d4d93 to c4bba4f (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/llm_qat/launch.sh (2)
63-66: Guard against division by zero when no GPUs are visible. On CPU-only nodes torch.cuda.device_count() returns 0 and the arithmetic expansion fails.
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+if [[ "$GPU_COUNT" -le 0 ]]; then
+  >&2 echo "Warning: No GPUs detected; defaulting save_steps to 192"
+  GPU_COUNT=1
+fi
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
148-152
: Move script-specific args aftermain.py
.Options like
--fsdp_transformer_layer_cls_to_wrap
are for the training script (main.py), not foraccelerate
; place$FSDP_ARGS
aftermain.py
to avoid "unrecognized argument" errors.File: examples/llm_qat/launch.sh (lines 148-152)
-CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ - main.py \ +CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE \ + main.py $FSDP_ARGS \
♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
314-315
: Make accelerate.prepare patch idempotent to avoid recursion on repeated patching.If init runs twice on the same Accelerator, _original_prepare will be overwritten with the wrapper, causing infinite recursion.
Apply:
- self.accelerator._original_prepare = self.accelerator.prepare - self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator) + if getattr(self.accelerator, "_original_prepare", None) is None: + self.accelerator._original_prepare = self.accelerator.prepare + self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)examples/llm_qat/launch.sh (1)
21-25
: Harden parse_value to prevent option bleed‑through (missing value).Current logic can treat the next flag as a value. Fail fast when the value is missing.
-# Helper function to parse a single argument value -parse_value() { - if [[ "$1" != *=* ]]; then shift; fi - echo "${1#*=}" -} +# Helper: extract value from "--opt=value" or "--opt value"; error if missing +parse_value() { + local first="$1"; shift || true + if [[ "$first" == *=* ]]; then + echo "${first#*=}" + return 0 + fi + local next="${1-}" + if [[ -z "$next" || "$next" == --* || "$next" == -* ]]; then + >&2 echo "Error: Missing value for option '$first'" + exit 2 + fi + echo "$next" +}
🧹 Nitpick comments (10)
modelopt/torch/utils/network.py (1)
73-77
: Good centralization of wrapper support; consider aligning is_parallel with SUPPORTED_WRAPPERS.Nice consolidation, including optional DeepSpeed. For consistency, update is_parallel (Line 90-93) to derive from SUPPORTED_WRAPPERS so FSDP/DS are also detected.
Apply:
def is_parallel(model: nn.Module) -> bool: """Check if a PyTorch model is parallelized.""" - return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) + return type(model) in SUPPORTED_WRAPPERSAlso applies to: 79-86
modelopt/torch/opt/plugins/peft.py (1)
91-94
: Avoid KeyError if a quantizer key is missing during load.Guard the lookup so older/newer checkpoints don’t fail when module sets differ.
Apply:
- if isinstance(module, TensorQuantizer): - module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)]) + if isinstance(module, TensorQuantizer): + key = get_unwrapped_name(name, self) + state = quantizer_state_dict.get(key) + if state is not None: + module.load_state_dict(state)modelopt/torch/quantization/plugins/transformers_trainer.py (3)
206-212
: Minor: prefer is_main_process for tqdm gating.should_save often aligns with main rank but is semantically about checkpointing. Using accelerator.is_main_process is clearer.
Apply:
- for batch in tqdm(data_loader, desc="Calibrating", disable=not self.args.should_save): + for batch in tqdm( + data_loader, desc="Calibrating", disable=not self.accelerator.is_main_process + ):
268-279
: Restore FSDP state_dict_type even on exceptions during save.Use try/finally to avoid leaking FULL_STATE_DICT if save_model raises.
Apply:
- original_type = self.accelerator.state.fsdp_plugin.state_dict_type - self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - outputs = super().save_model(*args, **kwargs) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)): - print_rank_0( - "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the " - "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing" - ) - self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type) + original_type = self.accelerator.state.fsdp_plugin.state_dict_type + self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + try: + outputs = super().save_model(*args, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)): + print_rank_0( + "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the " + "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing" + ) + finally: + self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
300-312
: Ensure buffer‑hiding is reverted if prepare() throws.Wrap the original prepare call in try/finally so _non_persistent_buffers_set is always restored.
Apply:
- outputs = self._original_prepare(*args, **kwargs) - - for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): - tq._non_persistent_buffers_set.clear() - tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq]) - - return outputs + try: + return self._original_prepare(*args, **kwargs) + finally: + for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): + tq._non_persistent_buffers_set.clear() + tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])examples/llm_qat/accelerate_config/fsdp1.yaml (1)
7-7
: Enabling activation checkpointing: confirm intended trade‑offs and consistency with training args.Turning on
fsdp_activation_checkpointing: true
increases recompute and reduces memory. Ensure this is intentional for fsdp1 runs and consistent with gradient checkpointing flags in launch.sh (which only enables--gradient_checkpointing
for ddp/deepspeed). If fsdp1 also requires model‑level GC, consider wiring that similarly.examples/llm_qat/launch.sh (4)
53-55
: Fix invalid‑arg error message to print the flag itself.
${1#*=}
trims up to=
, producing odd output for space‑separated flags.- >&2 printf "Error: Invalid argument ${1#*=}\n" + >&2 printf "Error: Invalid argument %s\n" "$1"
88-97
: Quote variable checks to avoid word‑splitting and unset pitfalls.-if [ -z $QUANT_CFG ]; then +if [ -z "${QUANT_CFG:-}" ]; then QUANT_ARGS="" else QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE" fi -OPTIONAL_ARGS="" -if [ ! -z $MAX_STEPS ]; then +OPTIONAL_ARGS="" +if [ -n "${MAX_STEPS:-}" ]; then OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" fi
60-60
: Gate tracing behind a DEBUG flag to reduce noise.-set -x +[[ "${DEBUG:-0}" == "1" ]] && set -x
148-179
: Optional: avoid sh -c and word‑splitting by using an argv array.Safer with paths containing spaces and avoids double parsing.
-start_time=$(date +%s) -sh -c "$CMD" +start_time=$(date +%s) +eval "$CMD" echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"Or better:
# Build an array instead of a string; then: "${CMD[@]}"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
examples/llm_qat/README.md
(0 hunks)examples/llm_qat/accelerate_config/deepspeed.yaml
(1 hunks)examples/llm_qat/accelerate_config/fsdp1.yaml
(1 hunks)examples/llm_qat/convert_sharded_ckpt.py
(0 hunks)examples/llm_qat/launch.sh
(4 hunks)examples/llm_qat/llama_factory/launch_llamafactory.sh
(0 hunks)examples/llm_qat/main.py
(2 hunks)examples/llm_qat/simple_qat_train.py
(3 hunks)examples/llm_qat/utils.py
(1 hunks)modelopt/torch/opt/conversion.py
(2 hunks)modelopt/torch/opt/dynamic.py
(1 hunks)modelopt/torch/opt/plugins/peft.py
(2 hunks)modelopt/torch/quantization/calib/histogram.py
(1 hunks)modelopt/torch/quantization/conversion.py
(2 hunks)modelopt/torch/quantization/nn/modules/quant_module.py
(1 hunks)modelopt/torch/quantization/plugins/transformers_trainer.py
(5 hunks)modelopt/torch/quantization/utils.py
(2 hunks)modelopt/torch/utils/network.py
(3 hunks)tests/_test_utils/examples/run_command.py
(1 hunks)tests/examples/llm_qat/test_llm_qat.py
(3 hunks)
💤 Files with no reviewable changes (3)
- examples/llm_qat/llama_factory/launch_llamafactory.sh
- examples/llm_qat/convert_sharded_ckpt.py
- examples/llm_qat/README.md
🚧 Files skipped from review as they are similar to previous changes (10)
- modelopt/torch/quantization/calib/histogram.py
- modelopt/torch/quantization/nn/modules/quant_module.py
- modelopt/torch/opt/dynamic.py
- tests/examples/llm_qat/test_llm_qat.py
- tests/_test_utils/examples/run_command.py
- modelopt/torch/quantization/conversion.py
- examples/llm_qat/accelerate_config/deepspeed.yaml
- examples/llm_qat/simple_qat_train.py
- modelopt/torch/quantization/utils.py
- examples/llm_qat/main.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/opt/conversion.py
modelopt/torch/opt/plugins/peft.py
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T21:46:46.318Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.318Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T20:14:34.725Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.725Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.
Applied to files:
modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (3)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
unwrap_model
(430-454)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
get_quantizer_state_dict
(446-456)modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
TensorQuantizer
(61-1139)disable
(394-399)collect
(1123-1133)modelopt/torch/quantization/utils.py (5)
calibrate_with_adapters
(275-286)disable_lora_quantizers_in_config
(289-296)get_quantizer_state_dict
(446-456)is_quantized
(239-243)set_quantizer_state_dict
(459-466)modelopt/torch/opt/conversion.py (5)
modelopt_state
(444-486)save
(489-507)restore_from_modelopt_state
(510-567)ModeloptStateManager
(63-311)is_converted
(102-127)modelopt/torch/utils/logging.py (1)
print_rank_0
(92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (9)
examples/llm_qat/utils.py (1)
172-176
: LGTM — simple, safe perplexity augmentation.Handles scalar/tensor eval_loss correctly and returns a float. No issues.
modelopt/torch/utils/network.py (1)
599-612
: get_unwrapped_name: DP/DS prefix cleanup + FSDP name normalization — looks correct.The conditional removeprefix for DP/DS and subsequent _convert_to_wrapped_module_name cover common wrappers well.
modelopt/torch/opt/conversion.py (1)
383-384
: Wrapper handling policy is consistent across apply/save/restore.
- apply_mode/modelopt_state: force unwrapping is appropriate.
- restore: raising on wrapped models is correct to avoid ambiguous state.
Also applies to: 469-470, 591-592
modelopt/torch/opt/plugins/peft.py (1)
60-65
: Centralized quantizer snapshotting — good move.Using get_quantizer_state_dict() avoids FSDP hangs and keeps logic consistent.
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
165-170
: Early restore/save of ModelOpt state — solid.Restoring on startup and snapshotting when already quantized keeps flows resilient across backends.
171-189
: Checkpoint payload layout is sensible; barrier guard present.Filtered state plus quantizer weights is compact and sufficient. Pre‑save barrier is guarded. LGTM.
190-197
: Restore path is correct and backward‑compatible.weights_only use is fine given PyTorch ≥ 2.6, and the optional weights guard prevents crashes on older blobs.
examples/llm_qat/launch.sh (1)
110-131
: Minor: ensure FSDP override is delivered to the script, not Accelerate.After moving
$FSDP_ARGS
behindmain.py
, the distillation‑time override (--fsdp_cpu_ram_efficient_loading False
) will reach main.py as intended. No change needed beyond the earlier repositioning; call this out to double‑check downstream parsing.examples/llm_qat/accelerate_config/fsdp1.yaml (1)
13-13
: Confirm FSDP version and correct fsdp_reshard_after_forward usageFile: examples/llm_qat/accelerate_config/fsdp1.yaml (line 13)
fsdp_reshard_after_forward accepts sharding-strategy strings for fsdp_version=1 but is a boolean for fsdp_version=2 — adjust based on your fsdp_version.
If fsdp_version == 2 apply:
- fsdp_reshard_after_forward: FULL_SHARD + fsdp_reshard_after_forward: true + fsdp_sharding_strategy: FULL_SHARDIf fsdp_version == 1, no change required.
…specific unitests; (#318) Signed-off-by: realAsma <[email protected]> Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: ? new tests, QATTrainer workflow fixes and simplification
Overview:
ModelOpt entry points now accept all distributed wrapped models (previously, DDP/FSDP-wrapped models could not be passed to ModelOpt entry points; after this PR all wrappers are supported, making ModelOpt workflows simpler).
Fixed QATTrainer FSDP2 workflow disruption and unblocked QLoRA FSDP2: Previously the ModelOpt states were restored after the weights were loaded for FSDP2. This broke ModelOpt workflow which required the ModelOpt states to be restored before weights are loaded. This workflow disruption made FSDP2 flow incompatible with QLoRA. The workflow is fixed in this PR.
This PR makes several improvements to QATTrainer and training workflow:
i. Simplified the QATTrainer workflow from the user side (removed the eval_only argument).
ii. Cleaned up QATTrainer to work seamlessly with various backends such as DDP, FSDP, FSDP2, and DeepSpeed.
iii. Added unit tests for llm_qat with various backends.
iv. Removed examples/llm_qat/convert_sharded_ckpt.py
Usage
See example/llm_qat/main.py
Testing
See the updated unit tests
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Tests
Chores